diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml
index fa8fa0cdee9..c6d23541a67 100644
--- a/.github/workflows/ci-additional.yaml
+++ b/.github/workflows/ci-additional.yaml
@@ -122,7 +122,7 @@ jobs:
           python xarray/util/print_versions.py
       - name: Install mypy
         run: |
-          python -m pip install "mypy" --force-reinstall
+          python -m pip install "mypy==1.11.2" --force-reinstall
 
       - name: Run mypy
         run: |
@@ -176,7 +176,7 @@ jobs:
           python xarray/util/print_versions.py
       - name: Install mypy
         run: |
-          python -m pip install "mypy" --force-reinstall
+          python -m pip install "mypy==1.11.2" --force-reinstall
 
       - name: Run mypy
         run: |
diff --git a/doc/api.rst b/doc/api.rst
index add1eed8944..40e05035d11 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -687,7 +687,7 @@ For manipulating, traversing, navigating, or mapping over the tree structure.
    DataTree.relative_to
    DataTree.iter_lineage
    DataTree.find_common_ancestor
-   DataTree.map_over_subtree
+   DataTree.map_over_datasets
    DataTree.pipe
    DataTree.match
    DataTree.filter
@@ -828,30 +828,26 @@ Index into all nodes in the subtree simultaneously.
 ..    DataTree.polyfit
 ..    DataTree.curvefit
 
-.. Aggregation
-.. -----------
+Aggregation
+-----------
 
-.. Aggregate data in all nodes in the subtree simultaneously.
+Aggregate data in all nodes in the subtree simultaneously.
 
-.. .. autosummary::
-..    :toctree: generated/
+.. autosummary::
+   :toctree: generated/
 
-..    DataTree.all
-..    DataTree.any
-..    DataTree.argmax
-..    DataTree.argmin
-..    DataTree.idxmax
-..    DataTree.idxmin
-..    DataTree.max
-..    DataTree.min
-..    DataTree.mean
-..    DataTree.median
-..    DataTree.prod
-..    DataTree.sum
-..    DataTree.std
-..    DataTree.var
-..    DataTree.cumsum
-..    DataTree.cumprod
+   DataTree.all
+   DataTree.any
+   DataTree.max
+   DataTree.min
+   DataTree.mean
+   DataTree.median
+   DataTree.prod
+   DataTree.sum
+   DataTree.std
+   DataTree.var
+   DataTree.cumsum
+   DataTree.cumprod
 
 .. ndarray methods
 .. ---------------
@@ -958,7 +954,7 @@ DataTree methods
 
    open_datatree
    open_groups
-   map_over_subtree
+   map_over_datasets
    DataTree.to_dict
    DataTree.to_netcdf
    DataTree.to_zarr
diff --git a/doc/getting-started-guide/quick-overview.rst b/doc/getting-started-guide/quick-overview.rst
index fbe81b2e895..a0c6efbf751 100644
--- a/doc/getting-started-guide/quick-overview.rst
+++ b/doc/getting-started-guide/quick-overview.rst
@@ -307,11 +307,11 @@ We can get a copy of the :py:class:`~xarray.Dataset` including the inherited coo
     ds_inherited = dt["simulation/coarse"].to_dataset()
     ds_inherited
 
-And you can get a copy of just the node local values of :py:class:`~xarray.Dataset` by setting the ``inherited`` keyword to ``False``:
+And you can get a copy of just the node local values of :py:class:`~xarray.Dataset` by setting the ``inherit`` keyword to ``False``:
 
 .. ipython:: python
 
-    ds_node_local = dt["simulation/coarse"].to_dataset(inherited=False)
+    ds_node_local = dt["simulation/coarse"].to_dataset(inherit=False)
     ds_node_local
 
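A minimal sketch of the renamed keyword in action (illustrative only, with made-up data; not part of the patch):

.. code:: python

    import xarray as xr

    tree = xr.DataTree.from_dict(
        {"/": xr.Dataset(coords={"time": [1, 2, 3]}), "/coarse": xr.Dataset()}
    )
    assert "time" in tree["coarse"].to_dataset().coords  # inherited by default
    assert "time" not in tree["coarse"].to_dataset(inherit=False).coords  # node-local only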
.. note::
diff --git a/doc/user-guide/data-structures.rst b/doc/user-guide/data-structures.rst
index b5e83789806..e5e89b0fbbd 100644
--- a/doc/user-guide/data-structures.rst
+++ b/doc/user-guide/data-structures.rst
@@ -771,7 +771,7 @@ Here there are four different coordinate variables, which apply to variables in
 ``station`` is used only for ``weather`` variables
 ``lat`` and ``lon`` are only use for ``satellite`` images
 
-Coordinate variables are inherited to descendent nodes, which means that
+Coordinate variables are inherited to descendent nodes, which is only possible because
 variables at different levels of a hierarchical DataTree are always aligned.
 Placing the ``time`` variable at the root node automatically indicates that it
 applies to all descendent nodes. Similarly, ``station`` is in the base
@@ -792,14 +792,15 @@ automatically includes coordinates from higher levels (e.g., ``time`` and
 
     dt2["/weather/temperature"].dataset
 
 Similarly, when you retrieve a Dataset through :py:func:`~xarray.DataTree.to_dataset` , the inherited coordinates are
-included by default unless you exclude them with the ``inherited`` flag:
+included by default unless you exclude them with the ``inherit`` flag:
 
 .. ipython:: python
 
     dt2["/weather/temperature"].to_dataset()
 
-    dt2["/weather/temperature"].to_dataset(inherited=False)
+    dt2["/weather/temperature"].to_dataset(inherit=False)
 
+For more examples and further discussion see :ref:`alignment and coordinate inheritance <hierarchical-data.alignment-and-coordinate-inheritance>`.
 
 .. _coordinates:
diff --git a/doc/user-guide/hierarchical-data.rst b/doc/user-guide/hierarchical-data.rst
index 84016348676..cdd30f9a302 100644
--- a/doc/user-guide/hierarchical-data.rst
+++ b/doc/user-guide/hierarchical-data.rst
@@ -1,7 +1,7 @@
-.. _hierarchical-data:
+.. _userguide.hierarchical-data:
 
 Hierarchical data
-==============================
+=================
 
 .. ipython:: python
     :suppress:
@@ -15,6 +15,8 @@ Hierarchical data
 
     %xmode minimal
 
+.. _why:
+
 Why Hierarchical Data?
 ----------------------
 
@@ -547,13 +549,13 @@ See that the same change (fast-forwarding by adding 10 years to the age of each
 Mapping Custom Functions Over Trees
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-You can map custom computation over each node in a tree using :py:meth:`xarray.DataTree.map_over_subtree`.
+You can map custom computation over each node in a tree using :py:meth:`xarray.DataTree.map_over_datasets`.
 You can map any function, so long as it takes :py:class:`xarray.Dataset` objects as one (or more) of the input arguments,
 and returns one (or more) xarray datasets.
 
 .. note::
 
-    Functions passed to :py:func:`~xarray.DataTree.map_over_subtree` cannot alter nodes in-place.
+    Functions passed to :py:func:`~xarray.DataTree.map_over_datasets` cannot alter nodes in-place.
     Instead they must return new :py:class:`xarray.Dataset` objects.
 
 For example, we can define a function to calculate the Root Mean Square of a timeseries
@@ -567,11 +569,11 @@ Then calculate the RMS value of these signals:
 
 .. ipython:: python
 
-    voltages.map_over_subtree(rms)
+    voltages.map_over_datasets(rms)
 
 .. _multiple trees:
 
-We can also use the :py:meth:`~xarray.map_over_subtree` decorator to promote a function which accepts datasets into one which
+We can also use the :py:meth:`~xarray.map_over_datasets` decorator to promote a function which accepts datasets into one which
 accepts datatrees.
 
 Operating on Multiple Trees
@@ -644,3 +646,148 @@ We could use this feature to quickly calculate the electrical power in our signa
 
     power = currents * voltages
     power
+
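A short sketch of the decorator form mentioned above (``voltages`` stands in for any DataTree whose nodes hold numeric timeseries datasets, mirroring the RMS example in the docs):

.. code:: python

    import xarray as xr

    @xr.map_over_datasets
    def rms(ds: xr.Dataset) -> xr.Dataset:
        # applied to the dataset held in each node of the tree
        return (ds**2).mean() ** 0.5

    rms(voltages)  # now accepts a DataTree and returns a DataTree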
+.. _hierarchical-data.alignment-and-coordinate-inheritance:
+
+Alignment and Coordinate Inheritance
+------------------------------------
+
+.. _data-alignment:
+
+Data Alignment
+~~~~~~~~~~~~~~
+
+The data in different datatree nodes are not totally independent. In particular, dimensions (and indexes) in child nodes must be exactly aligned with those in their parent nodes.
+Exact alignment means that shared dimensions must be the same length, and indexes along those dimensions must be equal.
+
+.. note::
+    If you were a previous user of the prototype `xarray-contrib/datatree <https://github.com/xarray-contrib/datatree>`_ package, this is different from what you're used to!
+    In that package the data stored in each node was completely unrelated. The data model is now slightly stricter.
+    This allows us to provide features like :ref:`coordinate-inheritance`.
+
+To demonstrate, let's first generate some example datasets which are not aligned with one another:
+
+.. ipython:: python
+
+    # (drop the attributes just to make the printed representation shorter)
+    ds = xr.tutorial.open_dataset("air_temperature").drop_attrs()
+
+    ds_daily = ds.resample(time="D").mean("time")
+    ds_weekly = ds.resample(time="W").mean("time")
+    ds_monthly = ds.resample(time="ME").mean("time")
+
+These datasets have different lengths along the ``time`` dimension, and are therefore not aligned along that dimension.
+
+.. ipython:: python
+
+    ds_daily.sizes
+    ds_weekly.sizes
+    ds_monthly.sizes
+
+We cannot store these non-alignable variables on a single :py:class:`~xarray.Dataset` object, because they do not exactly align:
+
+.. ipython:: python
+    :okexcept:
+
+    xr.align(ds_daily, ds_weekly, ds_monthly, join="exact")
+
+But we :ref:`previously said <why>` that multi-resolution data is a good use case for :py:class:`~xarray.DataTree`, so surely we should be able to store these in a single :py:class:`~xarray.DataTree`?
+If we first try to create a :py:class:`~xarray.DataTree` with these different-length time dimensions present in both parents and children, we will still get an alignment error:
+
+.. ipython:: python
+    :okexcept:
+
+    xr.DataTree.from_dict({"daily": ds_daily, "daily/weekly": ds_weekly})
+
+This is because DataTree checks that data in child nodes align exactly with their parents.
+
+.. note::
+    This requirement of aligned dimensions is similar to netCDF's concept of `inherited dimensions `_, as in netCDF-4 files dimensions are `visible to all child groups `_.
+
+This alignment check is performed up through the tree, all the way to the root, and is therefore equivalent to requiring that this :py:func:`~xarray.align` command succeeds:
+
+.. code:: python
+
+    xr.align(child.dataset, *(parent.dataset for parent in child.parents), join="exact")
+
+To represent our unalignable data in a single :py:class:`~xarray.DataTree`, we must instead place all variables which are a function of these different-length dimensions into nodes that are not direct descendants of one another, e.g. organize them as siblings.
+
+.. ipython:: python
+
+    dt = xr.DataTree.from_dict(
+        {"daily": ds_daily, "weekly": ds_weekly, "monthly": ds_monthly}
+    )
+    dt
+
+Now we have a valid :py:class:`~xarray.DataTree` structure which contains all the data at each different time frequency, stored in a separate group.
+
+This is a useful way to organize our data because we can still operate on all the groups at once.
+For example we can extract all three timeseries at a specific lat-lon location:
+
+.. ipython:: python
+
+    dt.sel(lat=75, lon=300)
+
+or compute the standard deviation of each timeseries to find out how it varies with sampling frequency:
+
+.. ipython:: python
+
+    dt.std(dim="time")
+
+.. _coordinate-inheritance:
+
+Coordinate Inheritance
+~~~~~~~~~~~~~~~~~~~~~~
+
+Notice that in the trees we constructed above there is some redundancy - the ``lat`` and ``lon`` variables appear in each sibling group, but are identical across the groups.
+
+.. ipython:: python
+
+    dt
+
+We can use "Coordinate Inheritance" to define them only once in a parent group and remove this redundancy, whilst still being able to access those coordinate variables from the child groups.
+
+.. note::
+    This is also a new feature relative to the prototype `xarray-contrib/datatree <https://github.com/xarray-contrib/datatree>`_ package.
+
+Let's instead place only the time-dependent variables in the child groups, and put the non-time-dependent ``lat`` and ``lon`` variables in the parent (root) group:
+
+.. ipython:: python
+
+    dt = xr.DataTree.from_dict(
+        {
+            "/": ds.drop_dims("time"),
+            "daily": ds_daily.drop_vars(["lat", "lon"]),
+            "weekly": ds_weekly.drop_vars(["lat", "lon"]),
+            "monthly": ds_monthly.drop_vars(["lat", "lon"]),
+        }
+    )
+    dt
+
+This is preferred to the previous representation because it now makes it clear that all of these datasets share common spatial grid coordinates.
+Defining the common coordinates just once also ensures that the spatial coordinates for each group cannot become out of sync with one another during operations.
+
+We can still access the coordinates defined in the parent groups from any of the child groups as if they were actually present on the child groups:
+
+.. ipython:: python
+
+    dt.daily.coords
+    dt["daily/lat"]
+
+As we can still access them, we say that the ``lat`` and ``lon`` coordinates in the child groups have been "inherited" from their common parent group.
+
+If we print just one of the child nodes, it will still display inherited coordinates, but explicitly mark them as such:
+
+.. ipython:: python
+
+    print(dt["/daily"])
+
+This helps to differentiate which variables are defined on the datatree node that you are currently looking at, and which were defined somewhere above it.
+
+We can also still perform all the same operations on the whole tree:
+
+.. ipython:: python
+
+    dt.sel(lat=[75], lon=[300])
+
+    dt.std(dim="time")
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index b374721c8ee..deb8cc9bdc3 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -23,7 +23,7 @@ New Features
 ~~~~~~~~~~~~
 - ``DataTree`` related functionality is now exposed in the main ``xarray`` public API. This includes:
   ``xarray.DataTree``, ``xarray.open_datatree``, ``xarray.open_groups``,
-  ``xarray.map_over_subtree``, ``xarray.register_datatree_accessor`` and
+  ``xarray.map_over_datasets``, ``xarray.register_datatree_accessor`` and
   ``xarray.testing.assert_isomorphic``.
By `Owen Littlejohns `_, `Eni Awowale `_, diff --git a/xarray/__init__.py b/xarray/__init__.py index e3b7ec469e9..1e1bfe9a770 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -34,7 +34,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree -from xarray.core.datatree_mapping import TreeIsomorphismError, map_over_subtree +from xarray.core.datatree_mapping import TreeIsomorphismError, map_over_datasets from xarray.core.extensions import ( register_dataarray_accessor, register_dataset_accessor, @@ -86,7 +86,7 @@ "load_dataarray", "load_dataset", "map_blocks", - "map_over_subtree", + "map_over_datasets", "merge", "ones_like", "open_dataarray", diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index f1ab57badb2..29a9d3d18f5 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -3,7 +3,7 @@ import functools import io import os -from collections.abc import Callable, Iterable +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import numpy as np @@ -465,7 +465,7 @@ def open_datatree( use_cftime=None, decode_timedelta=None, format=None, - group: str | Iterable[str] | Callable | None = None, + group: str | None = None, lock=None, invalid_netcdf=None, phony_dims=None, @@ -511,7 +511,7 @@ def open_groups_as_dict( use_cftime=None, decode_timedelta=None, format=None, - group: str | Iterable[str] | Callable | None = None, + group: str | None = None, lock=None, invalid_netcdf=None, phony_dims=None, diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 99ca263fa67..4b6c5e16334 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -3,7 +3,7 @@ import functools import operator import os -from collections.abc import Callable, Iterable +from collections.abc import Iterable from contextlib import suppress from typing import TYPE_CHECKING, Any @@ -700,7 +700,7 @@ def open_datatree( drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, - group: str | Iterable[str] | Callable | None = None, + group: str | None = None, format="NETCDF4", clobber=True, diskless=False, @@ -745,7 +745,7 @@ def open_groups_as_dict( drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, - group: str | Iterable[str] | Callable | None = None, + group: str | None = None, format="NETCDF4", clobber=True, diskless=False, diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index def932bde4a..107214f0476 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -3,7 +3,7 @@ import json import os import warnings -from collections.abc import Callable, Iterable +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import numpy as np @@ -1282,7 +1282,7 @@ def open_datatree( drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, - group: str | Iterable[str] | Callable | None = None, + group: str | None = None, mode="r", synchronizer=None, consolidated=None, @@ -1328,7 +1328,7 @@ def open_groups_as_dict( drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, - group: str | Iterable[str] | Callable | None = None, + group: str | None = None, mode="r", synchronizer=None, consolidated=None, diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index b557ad44a32..6b1029791ea 100644 --- a/xarray/core/_aggregations.py +++ 
b/xarray/core/_aggregations.py @@ -19,6 +19,1314 @@ flox_available = module_available("flox") +class DataTreeAggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError() + + def count( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + pandas.DataFrame.count + dask.dataframe.DataFrame.count + Dataset.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.count() + + Group: / + Dimensions: () + Data variables: + foo int64 8B 5 + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def all( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... 
"time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.all() + + Group: / + Dimensions: () + Data variables: + foo bool 1B False + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def any( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.any() + + Group: / + Dimensions: () + Data variables: + foo bool 1B True + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def max( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. 
+ + Returns + ------- + reduced : DataTree + New DataTree with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.max() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 3.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.max(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def min( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.min() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. 
+ + >>> dt.min(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.mean() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.6 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.mean(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def prod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. 
If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.prod() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.prod(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.prod(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. 
+ **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.sum() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.sum(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.sum(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... 
), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.std() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.02 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.std(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> dt.std(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.14 + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.var() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.04 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.var(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. 
+ + >>> dt.var(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.3 + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def median( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.median() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 2.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.median(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumsum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumsum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. 
If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumsum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``cumsum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumsum + dask.array.cumsum + Dataset.cumsum + DataArray.cumsum + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.cumsum() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumsum(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 nan + """ + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumprod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumprod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumprod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumprod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``cumprod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumprod + dask.array.cumprod + Dataset.cumprod + DataArray.cumprod + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. 
+ + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.cumprod() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumprod(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 nan + """ + return self.reduce( + duck_array_ops.cumprod, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + class DatasetAggregations: __slots__ = () diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 91a184d55cd..e2a6676252a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -31,7 +31,7 @@ from xarray.core.merge import merge_attrs, merge_coordinates_without_align from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import Dims, T_DataArray -from xarray.core.utils import is_dict_like, is_scalar, parse_dims +from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array @@ -1841,16 +1841,15 @@ def dot( einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} + dot_dims: set[Hashable] if dim is None: # find dimensions that occur more than once dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) - dim = tuple(d for d, c in dim_counts.items() if c > 1) + dot_dims = {d for d, c in dim_counts.items() if c > 1} else: - dim = parse_dims(dim, all_dims=tuple(all_dims)) - - dot_dims: set[Hashable] = set(dim) + dot_dims = parse_dims_as_set(dim, all_dims=set(all_dims)) # dimensions to be parallelized broadcast_dims = common_dims - dot_dims diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index a6dec863aec..91ef9b6ccad 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -849,14 +849,14 @@ def _update_coords( from xarray.core.datatree import check_alignment # create updated node (`.to_dataset` makes a copy so this doesn't modify in-place) - node_ds = self._data.to_dataset(inherited=False) + node_ds = self._data.to_dataset(inherit=False) node_ds.coords._update_coords(coords, indexes) # check consistency *before* modifying anything in-place # TODO can we clean up the signature of check_alignment to make this less awkward? 
if self._data.parent is not None: parent_ds = self._data.parent._to_dataset_view( - inherited=True, rebuild_dims=False + inherit=True, rebuild_dims=False ) else: parent_ds = None diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a7dedd2ed07..e0cd92bab6e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -118,6 +118,7 @@ is_duck_dask_array, is_scalar, maybe_wrap_array, + parse_dims_as_set, ) from xarray.core.variable import ( IndexVariable, @@ -6986,18 +6987,7 @@ def reduce( " Please use 'dim' instead." ) - if dim is None or dim is ...: - dims = set(self.dims) - elif isinstance(dim, str) or not isinstance(dim, Iterable): - dims = {dim} - else: - dims = set(dim) - - missing_dimensions = tuple(d for d in dims if d not in self.dims) - if missing_dimensions: - raise ValueError( - f"Dimensions {missing_dimensions} not found in data dimensions {tuple(self.dims)}" - ) + dims = parse_dims_as_set(dim, set(self._dims.keys())) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 3b82f7a58d2..e503b5c0741 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload from xarray.core import utils +from xarray.core._aggregations import DataTreeAggregations from xarray.core.alignment import align from xarray.core.common import TreeAttrAccessMixin from xarray.core.coordinates import Coordinates, DataTreeCoordinates @@ -22,7 +23,7 @@ from xarray.core.datatree_mapping import ( TreeIsomorphismError, check_isomorphic, - map_over_subtree, + map_over_datasets, ) from xarray.core.formatting import datatree_repr, dims_and_coords_repr from xarray.core.formatting_html import ( @@ -41,6 +42,7 @@ drop_dims_from_indexers, either_dict_or_kwargs, maybe_wrap_array, + parse_dims_as_set, ) from xarray.core.variable import Variable @@ -57,6 +59,7 @@ from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.merge import CoercibleMapping, CoercibleValue from xarray.core.types import ( + Dims, ErrorOptions, ErrorOptionsWithWarn, NetcdfWriteModes, @@ -154,7 +157,7 @@ def check_alignment( for child_name, child in children.items(): child_path = str(NodePath(path) / child_name) - child_ds = child.to_dataset(inherited=False) + child_ds = child.to_dataset(inherit=False) check_alignment(child_path, child_ds, base_ds, child.children) @@ -251,14 +254,14 @@ def _constructor( def __setitem__(self, key, val) -> None: raise AttributeError( "Mutation of the DatasetView is not allowed, please use `.__setitem__` on the wrapping DataTree node, " - "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_datasets`," "use `.copy()` first to get a mutable version of the input dataset." ) def update(self, other) -> NoReturn: raise AttributeError( "Mutation of the DatasetView is not allowed, please use `.update` on the wrapping DataTree node, " - "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_datasets`," "use `.copy()` first to get a mutable version of the input dataset." 
) @@ -399,6 +402,7 @@ def map( # type: ignore[override] class DataTree( NamedNode["DataTree"], + DataTreeAggregations, TreeAttrAccessMixin, Mapping[str, "DataArray | DataTree"], ): @@ -503,8 +507,8 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: f"parent {parent.name} already contains a variable named {name}" ) path = str(NodePath(parent.path) / name) - node_ds = self.to_dataset(inherited=False) - parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) + node_ds = self.to_dataset(inherit=False) + parent_ds = parent._to_dataset_view(rebuild_dims=False, inherit=True) check_alignment(path, node_ds, parent_ds, self.children) _deduplicate_inherited_coordinates(self, parent) @@ -529,14 +533,14 @@ def _dims(self) -> ChainMap[Hashable, int]: def _indexes(self) -> ChainMap[Hashable, Index]: return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents)) - def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView: - coord_vars = self._coord_variables if inherited else self._node_coord_variables + def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView: + coord_vars = self._coord_variables if inherit else self._node_coord_variables variables = dict(self._data_variables) variables |= coord_vars if rebuild_dims: dims = calculate_dimensions(variables) - elif inherited: - # Note: rebuild_dims=False with inherited=True can create + elif inherit: + # Note: rebuild_dims=False with inherit=True can create # technically invalid Dataset objects because it still includes # dimensions that are only defined on parent data variables # (i.e. not present on any parent coordinate variables). @@ -548,7 +552,7 @@ def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView: # ... "/b": xr.Dataset(), # ... } # ... ) - # >>> ds = tree["b"]._to_dataset_view(rebuild_dims=False, inherited=True) + # >>> ds = tree["b"]._to_dataset_view(rebuild_dims=False, inherit=True) # >>> ds # Size: 0B # Dimensions: (x: 2) @@ -572,7 +576,7 @@ def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView: coord_names=set(self._coord_variables), dims=dims, attrs=self._attrs, - indexes=dict(self._indexes if inherited else self._node_indexes), + indexes=dict(self._indexes if inherit else self._node_indexes), encoding=self._encoding, close=None, ) @@ -591,7 +595,7 @@ def dataset(self) -> DatasetView: -------- DataTree.to_dataset """ - return self._to_dataset_view(rebuild_dims=True, inherited=True) + return self._to_dataset_view(rebuild_dims=True, inherit=True) @dataset.setter def dataset(self, data: Dataset | None = None) -> None: @@ -602,13 +606,13 @@ def dataset(self, data: Dataset | None = None) -> None: # xarray-contrib/datatree ds = dataset - def to_dataset(self, inherited: bool = True) -> Dataset: + def to_dataset(self, inherit: bool = True) -> Dataset: """ Return the data in this node as a new xarray.Dataset object. Parameters ---------- - inherited : bool, optional + inherit : bool, optional If False, only include coordinates and indexes defined at the level of this DataTree node, excluding any inherited coordinates and indexes. 
@@ -616,16 +620,16 @@ def to_dataset(self, inherited: bool = True) -> Dataset: -------- DataTree.dataset """ - coord_vars = self._coord_variables if inherited else self._node_coord_variables + coord_vars = self._coord_variables if inherit else self._node_coord_variables variables = dict(self._data_variables) variables |= coord_vars - dims = calculate_dimensions(variables) if inherited else dict(self._node_dims) + dims = calculate_dimensions(variables) if inherit else dict(self._node_dims) return Dataset._construct_direct( variables, set(coord_vars), dims, None if self._attrs is None else dict(self._attrs), - dict(self._indexes if inherited else self._node_indexes), + dict(self._indexes if inherit else self._node_indexes), None if self._encoding is None else dict(self._encoding), self._close, ) @@ -794,7 +798,7 @@ def _replace_node( children: dict[str, DataTree] | Default = _default, ) -> None: - ds = self.to_dataset(inherited=False) if data is _default else data + ds = self.to_dataset(inherit=False) if data is _default else data if children is _default: children = self._children @@ -804,7 +808,7 @@ def _replace_node( raise ValueError(f"node already contains a variable named {child_name}") parent_ds = ( - self.parent._to_dataset_view(rebuild_dims=False, inherited=True) + self.parent._to_dataset_view(rebuild_dims=False, inherit=True) if self.parent is not None else None ) @@ -813,6 +817,9 @@ def _replace_node( if data is not _default: self._set_node_data(ds) + if self.parent is not None: + _deduplicate_inherited_coordinates(self, self.parent) + self.children = children def _copy_node( @@ -823,7 +830,7 @@ def _copy_node( new_node = super()._copy_node() - data = self._to_dataset_view(rebuild_dims=False, inherited=False) + data = self._to_dataset_view(rebuild_dims=False, inherit=False) if deep: data = data.copy(deep=True) new_node._set_node_data(data) @@ -993,7 +1000,7 @@ def update( raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") vars_merge_result = dataset_update_method( - self.to_dataset(inherited=False), new_variables + self.to_dataset(inherit=False), new_variables ) data = Dataset._construct_direct(**vars_merge_result._asdict()) @@ -1323,7 +1330,7 @@ def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: -------- match pipe - map_over_subtree + map_over_datasets """ filtered_nodes = { node.path: node.dataset for node in self.subtree if filterfunc(node) @@ -1349,7 +1356,7 @@ def match(self, pattern: str) -> DataTree: -------- filter pipe - map_over_subtree + map_over_datasets Examples -------- @@ -1376,7 +1383,7 @@ def match(self, pattern: str) -> DataTree: } return DataTree.from_dict(matching_nodes, name=self.root.name) - def map_over_subtree( + def map_over_datasets( self, func: Callable, *args: Iterable[Any], @@ -1410,7 +1417,7 @@ def map_over_subtree( # TODO this signature means that func has no way to know which node it is being called upon - change? 
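        # Illustrative usage only (hypothetical tree, not part of this patch):
        #     >>> dt = xr.DataTree.from_dict({"/a": xr.Dataset({"v": ("t", [3.0, 4.0])})})
        #     >>> dt.map_over_datasets(lambda ds: (ds**2).mean() ** 0.5)
        # The function receives each node's Dataset and a new DataTree is
        # returned; nodes are never modified in place.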
# TODO fix this typing error - return map_over_subtree(func)(self, *args, **kwargs) + return map_over_datasets(func)(self, *args, **kwargs) def pipe( self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any @@ -1609,6 +1616,38 @@ def to_zarr( **kwargs, ) + def _get_all_dims(self) -> set: + all_dims = set() + for node in self.subtree: + all_dims.update(node._node_dims) + return all_dims + + def reduce( + self, + func: Callable, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + keepdims: bool = False, + numeric_only: bool = False, + **kwargs: Any, + ) -> Self: + """Reduce this tree by applying `func` along some dimension(s).""" + dims = parse_dims_as_set(dim, self._get_all_dims()) + result = {} + for node in self.subtree: + reduce_dims = [d for d in node._node_dims if d in dims] + node_result = node.dataset.reduce( + func, + reduce_dims, + keep_attrs=keep_attrs, + keepdims=keepdims, + numeric_only=numeric_only, + **kwargs, + ) + result[node.path] = node_result + return type(self).from_dict(result, name=self.name) + def _selective_indexing( self, func: Callable[[Dataset, Mapping[Any, Any]], Dataset], @@ -1619,9 +1658,7 @@ def _selective_indexing( dimensions and inherited coordinates gracefully by only applying indexing at each node selectively. """ - all_dims = set() - for node in self.subtree: - all_dims.update(node._node_dims) + all_dims = self._get_all_dims() indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims) result = {} diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index 36665a0d153..908d0697525 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -85,7 +85,7 @@ def _datatree_to_netcdf( unlimited_dims = {} for node in dt.subtree: - ds = node.to_dataset(inherited=False) + ds = node.to_dataset(inherit=False) group_path = node.path if ds is None: _create_empty_netcdf_group(filepath, group_path, mode, engine) @@ -151,7 +151,7 @@ def _datatree_to_zarr( ) for node in dt.subtree: - ds = node.to_dataset(inherited=False) + ds = node.to_dataset(inherit=False) group_path = node.path if ds is None: _create_empty_zarr_group(store, group_path, mode) diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index f3e7ce348b1..2817effa856 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -75,7 +75,7 @@ def check_isomorphic( raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) -def map_over_subtree(func: Callable) -> Callable: +def map_over_datasets(func: Callable) -> Callable: """ Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. @@ -115,8 +115,8 @@ def map_over_subtree(func: Callable) -> Callable: See also -------- - DataTree.map_over_subtree - DataTree.map_over_subtree_inplace + DataTree.map_over_datasets + DataTree.map_over_datasets_inplace DataTree.subtree """ @@ -125,7 +125,7 @@ def map_over_subtree(func: Callable) -> Callable: # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? 
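    # For example (mirrors test_single_dt_arg below; ``dt`` is a hypothetical tree):
    #     >>> @map_over_datasets
    #     ... def times_ten(ds):
    #     ...     return 10.0 * ds
    #     >>> times_ten(dt)  # now accepts and returns DataTree objects, not Datasets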
@functools.wraps(func) - def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: + def _map_over_datasets(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" from xarray.core.datatree import DataTree @@ -230,7 +230,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: else: return tuple(result_trees) - return _map_over_subtree + return _map_over_datasets def _handle_errors_with_path_context(path: str): diff --git a/xarray/core/datatree_ops.py b/xarray/core/datatree_ops.py index 693b0a402b9..69c1b9e9082 100644 --- a/xarray/core/datatree_ops.py +++ b/xarray/core/datatree_ops.py @@ -4,7 +4,7 @@ import textwrap from xarray.core.dataset import Dataset -from xarray.core.datatree_mapping import map_over_subtree +from xarray.core.datatree_mapping import map_over_datasets """ Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree. @@ -17,7 +17,7 @@ _MAPPED_DOCSTRING_ADDENDUM = ( "This method was copied from :py:class:`xarray.Dataset`, but has been altered to " "call the method on the Datasets stored in every node of the subtree. " - "See the `map_over_subtree` function for more details." + "See the `map_over_datasets` function for more details." ) # TODO equals, broadcast_equals etc. @@ -174,7 +174,7 @@ def _wrap_then_attach_to_cls( target_cls_dict, source_cls, methods_to_set, wrap_func=None ): """ - Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree). + Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_datasets). Result is like having written this in the classes' definition: ``` @@ -206,7 +206,7 @@ def method_name(self, *args, **kwargs): ) target_cls_dict[method_name] = wrapped_method - if wrap_func is map_over_subtree: + if wrap_func is map_over_datasets: # Add a paragraph to the method's docstring explaining how it's been mapped orig_method_docstring = orig_method.__doc__ @@ -277,7 +277,7 @@ class MappedDatasetMethodsMixin: target_cls_dict=vars(), source_cls=Dataset, methods_to_set=_ALL_DATASET_METHODS_TO_MAP, - wrap_func=map_over_subtree, + wrap_func=map_over_datasets, ) @@ -291,7 +291,7 @@ class MappedDataWithCoords: target_cls_dict=vars(), source_cls=Dataset, methods_to_set=_DATA_WITH_COORDS_METHODS_TO_MAP, - wrap_func=map_over_subtree, + wrap_func=map_over_datasets, ) @@ -305,5 +305,5 @@ class DataTreeArithmeticMixin: target_cls_dict=vars(), source_cls=Dataset, methods_to_set=_ARITHMETIC_METHODS_TO_MAP, - wrap_func=map_over_subtree, + wrap_func=map_over_datasets, ) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 91864d4581a..5ef3b9924a0 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -1113,7 +1113,7 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: summary.append(f"{dims_start}({dims_values})") if node._node_coord_variables: - node_coords = node.to_dataset(inherited=False).coords + node_coords = node.to_dataset(inherit=False).coords summary.append(coords_repr(node_coords, col_width=col_width, max_rows=max_rows)) if show_inherited and inherited_coords: diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 34c7a93bd7a..e0a461caea7 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: def 
datatree_node_repr(group_title: str, dt: DataTree) -> str: header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>
"] - ds = dt._to_dataset_view(rebuild_dims=False, inherited=True) + ds = dt._to_dataset_view(rebuild_dims=False, inherit=True) sections = [ children_section(dt.children), diff --git a/xarray/core/utils.py b/xarray/core/utils.py index e5168342e1e..e2781366265 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -59,6 +59,7 @@ MutableMapping, MutableSet, Sequence, + Set, ValuesView, ) from enum import Enum @@ -831,7 +832,7 @@ def drop_dims_from_indexers( @overload -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -841,7 +842,7 @@ def parse_dims( @overload -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -850,7 +851,7 @@ def parse_dims( ) -> tuple[Hashable, ...] | None | EllipsisType: ... -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -891,6 +892,47 @@ def parse_dims( return tuple(dim) +@overload +def parse_dims_as_set( + dim: Dims, + all_dims: set[Hashable], + *, + check_exists: bool = True, + replace_none: Literal[True] = True, +) -> set[Hashable]: ... + + +@overload +def parse_dims_as_set( + dim: Dims, + all_dims: set[Hashable], + *, + check_exists: bool = True, + replace_none: Literal[False], +) -> set[Hashable] | None | EllipsisType: ... + + +def parse_dims_as_set( + dim: Dims, + all_dims: set[Hashable], + *, + check_exists: bool = True, + replace_none: bool = True, +) -> set[Hashable] | None | EllipsisType: + """Like parse_dims_as_tuple, but returning a set instead of a tuple.""" + # TODO: Consider removing parse_dims_as_tuple? + if dim is None or dim is ...: + if replace_none: + return all_dims + return dim + if isinstance(dim, str): + dim = {dim} + dim = set(dim) + if check_exists: + _check_dims(dim, all_dims) + return dim + + @overload def parse_ordered_dims( dim: Dims, @@ -958,7 +1000,7 @@ def parse_ordered_dims( return dims[:idx] + other_dims + dims[idx + 1 :] else: # mypy cannot resolve that the sequence cannot contain "..." - return parse_dims( # type: ignore[call-overload] + return parse_dims_as_tuple( # type: ignore[call-overload] dim=dim, all_dims=all_dims, check_exists=check_exists, @@ -966,7 +1008,7 @@ def parse_ordered_dims( ) -def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None: +def _check_dims(dim: Set[Hashable], all_dims: Set[Hashable]) -> None: wrong_dims = (dim - all_dims) - {...} if wrong_dims: wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims) diff --git a/xarray/namedarray/_aggregations.py b/xarray/namedarray/_aggregations.py index 48001422386..139cea83b5b 100644 --- a/xarray/namedarray/_aggregations.py +++ b/xarray/namedarray/_aggregations.py @@ -61,10 +61,7 @@ def count( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -116,8 +113,7 @@ def all( -------- >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x", - ... np.array([True, True, True, True, True, False], dtype=bool), + ... "x", np.array([True, True, True, True, True, False], dtype=bool) ... ) >>> na Size: 6B @@ -170,8 +166,7 @@ def any( -------- >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x", - ... np.array([True, True, True, True, True, False], dtype=bool), + ... "x", np.array([True, True, True, True, True, False], dtype=bool) ... 
) >>> na Size: 6B @@ -230,10 +225,7 @@ def max( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -298,10 +290,7 @@ def min( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -370,10 +359,7 @@ def mean( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -449,10 +435,7 @@ def prod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -535,10 +518,7 @@ def sum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -618,10 +598,7 @@ def std( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -701,10 +678,7 @@ def var( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -780,10 +754,7 @@ def median( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -857,10 +828,7 @@ def cumsum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -934,10 +902,7 @@ def cumprod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... 
) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 8ca4711acad..72e8a7464c5 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -131,7 +131,7 @@ def test_to_netcdf_inherited_coords(self, tmpdir): roundtrip_dt = open_datatree(filepath, engine=self.engine) assert_equal(original_dt, roundtrip_dt) subtree = cast(DataTree, roundtrip_dt["/sub"]) - assert "x" not in subtree.to_dataset(inherited=False).coords + assert "x" not in subtree.to_dataset(inherit=False).coords def test_netcdf_encoding(self, tmpdir, simple_datatree): filepath = tmpdir / "test.nc" @@ -320,7 +320,7 @@ def test_to_zarr_inherited_coords(self, tmpdir): roundtrip_dt = open_datatree(filepath, engine="zarr") assert_equal(original_dt, roundtrip_dt) subtree = cast(DataTree, roundtrip_dt["/sub"]) - assert "x" not in subtree.to_dataset(inherited=False).coords + assert "x" not in subtree.to_dataset(inherit=False).coords def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None: """Test `open_groups` opens a zarr store with the `simple_datatree` structure.""" diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 1178498de19..eafc11b630c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5615,7 +5615,7 @@ def test_reduce_bad_dim(self) -> None: data = create_test_data() with pytest.raises( ValueError, - match=r"Dimensions \('bad_dim',\) not found in data dimensions", + match=re.escape("Dimension(s) 'bad_dim' do not exist"), ): data.mean(dim="bad_dim") @@ -5644,7 +5644,7 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: data = create_test_data() with pytest.raises( ValueError, - match=r"Dimensions \('bad_dim',\) not found in data dimensions", + match=re.escape("Dimension(s) 'bad_dim' do not exist"), ): getattr(data, func)(dim="bad_dim") diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 1de53a35311..69c6566f88c 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -222,12 +222,12 @@ def test_to_dataset_inherited(self): tree = DataTree.from_dict({"/": base, "/sub": sub}) subtree = typing.cast(DataTree, tree["sub"]) - assert_identical(tree.to_dataset(inherited=False), base) - assert_identical(subtree.to_dataset(inherited=False), sub) + assert_identical(tree.to_dataset(inherit=False), base) + assert_identical(subtree.to_dataset(inherit=False), sub) sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]}) # no "b" - assert_identical(tree.to_dataset(inherited=True), base) - assert_identical(subtree.to_dataset(inherited=True), sub_and_base) + assert_identical(tree.to_dataset(inherit=True), base) + assert_identical(subtree.to_dataset(inherit=True), sub_and_base) class TestVariablesChildrenNameCollisions: @@ -368,8 +368,8 @@ def test_update_inherited_coords(self): # DataTree.identical() currently does not require that non-inherited # coordinates are defined identically, so we need to check this # explicitly - actual_node = actual.children["b"].to_dataset(inherited=False) - expected_node = expected.children["b"].to_dataset(inherited=False) + actual_node = actual.children["b"].to_dataset(inherit=False) + expected_node = expected.children["b"].to_dataset(inherit=False) assert_identical(actual_node, expected_node) @@ -414,7 +414,7 @@ def test_copy_coord_inheritance(self) -> None: {"/": 
xr.Dataset(coords={"x": [0, 1]}), "/c": DataTree()} ) tree2 = tree.copy() - node_ds = tree2.children["c"].to_dataset(inherited=False) + node_ds = tree2.children["c"].to_dataset(inherit=False) assert_identical(node_ds, xr.Dataset()) def test_deepcopy(self, create_test_datatree): @@ -1267,8 +1267,8 @@ def test_inherited_dims(self): assert dt.c.sizes == {"x": 2, "y": 3} # dataset objects created from nodes should not assert dt.b.dataset.sizes == {"y": 1} - assert dt.b.to_dataset(inherited=True).sizes == {"y": 1} - assert dt.b.to_dataset(inherited=False).sizes == {"y": 1} + assert dt.b.to_dataset(inherit=True).sizes == {"y": 1} + assert dt.b.to_dataset(inherit=False).sizes == {"y": 1} def test_inherited_coords_index(self): dt = DataTree.from_dict( @@ -1306,15 +1306,28 @@ def test_inherited_coords_with_index_are_deduplicated(self): "/b": xr.Dataset(coords={"x": [1, 2]}), } ) - child_dataset = dt.children["b"].to_dataset(inherited=False) + child_dataset = dt.children["b"].to_dataset(inherit=False) expected = xr.Dataset() assert_identical(child_dataset, expected) dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]}) - child_dataset = dt.children["c"].to_dataset(inherited=False) + child_dataset = dt.children["c"].to_dataset(inherit=False) expected = xr.Dataset({"foo": ("x", [4, 5])}) assert_identical(child_dataset, expected) + def test_deduplicated_after_setitem(self): + # regression test for GH #9601 + dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2]}), + "/b": None, + } + ) + dt["b/x"] = dt["x"] + child_dataset = dt.children["b"].to_dataset(inherit=False) + expected = xr.Dataset() + assert_identical(child_dataset, expected) + def test_inconsistent_dims(self): expected_msg = _exact_match( """ @@ -1648,9 +1661,8 @@ def test_sel(self): assert_equal(actual, expected) -class TestDSMethodInheritance: +class TestAggregations: - @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_reduce_method(self): ds = xr.Dataset({"a": ("x", [False, True, False])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1660,7 +1672,6 @@ def test_reduce_method(self): result = dt.any() assert_equal(result, expected) - @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_nan_reduce_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1670,7 +1681,6 @@ def test_nan_reduce_method(self): result = dt.mean() assert_equal(result, expected) - @pytest.mark.xfail(reason="cum methods not implemented yet") def test_cum_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1685,6 +1695,41 @@ def test_cum_method(self): result = dt.cumsum() assert_equal(result, expected) + def test_dim_argument(self): + dt = DataTree.from_dict( + { + "/a": xr.Dataset({"A": ("x", [1, 2])}), + "/b": xr.Dataset({"B": ("y", [1, 2])}), + } + ) + + expected = DataTree.from_dict( + { + "/a": xr.Dataset({"A": 1.5}), + "/b": xr.Dataset({"B": 1.5}), + } + ) + actual = dt.mean() + assert_equal(expected, actual) + + actual = dt.mean(dim=...) 
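+        # dim=... explicitly requests reduction over every dimension in the
+        # tree, so the result should match the default dt.mean() call above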
+ assert_equal(expected, actual) + + expected = DataTree.from_dict( + { + "/a": xr.Dataset({"A": 1.5}), + "/b": xr.Dataset({"B": ("y", [1.0, 2.0])}), + } + ) + actual = dt.mean("x") + assert_equal(expected, actual) + + with pytest.raises( + ValueError, + match=re.escape("Dimension(s) 'invalid' do not exist."), + ): + dt.mean("invalid") + class TestOps: @pytest.mark.xfail(reason="arithmetic not implemented yet") @@ -1741,7 +1786,7 @@ def test_arithmetic_inherited_coords(self): tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) actual: DataTree = 2 * tree # type: ignore[assignment,operator] - actual_dataset = actual.children["foo"].to_dataset(inherited=False) + actual_dataset = actual.children["foo"].to_dataset(inherit=False) assert "x" not in actual_dataset.coords expected = tree.copy() @@ -1760,7 +1805,7 @@ def test_tree(self, create_test_datatree): class TestDocInsertion: - """Tests map_over_subtree docstring injection.""" + """Tests map_over_datasets docstring injection.""" def test_standard_doc(self): @@ -1794,7 +1839,7 @@ def test_standard_doc(self): .. note:: This method was copied from :py:class:`xarray.Dataset`, but has been altered to call the method on the Datasets stored in every - node of the subtree. See the `map_over_subtree` function for more + node of the subtree. See the `map_over_datasets` function for more details. Normally, it should not be necessary to call this method in user code, @@ -1825,7 +1870,7 @@ def test_one_liner(self): This method was copied from :py:class:`xarray.Dataset`, but has been altered to call the method on the Datasets stored in every node of the subtree. See - the `map_over_subtree` function for more details.""" + the `map_over_datasets` function for more details.""" ) actual_doc = insert_doc_addendum(mixin_doc, _MAPPED_DOCSTRING_ADDENDUM) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 9a7d3009c3b..766df76a259 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -5,7 +5,7 @@ from xarray.core.datatree_mapping import ( TreeIsomorphismError, check_isomorphic, - map_over_subtree, + map_over_datasets, ) from xarray.testing import assert_equal, assert_identical @@ -92,7 +92,7 @@ def test_checking_from_root(self, create_test_datatree): class TestMapOverSubTree: def test_no_trees_passed(self): - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -104,7 +104,7 @@ def test_not_isomorphic(self, create_test_datatree): dt2 = create_test_datatree() dt2["set1/set2/extra"] = xr.DataTree(name="extra") - @map_over_subtree + @map_over_datasets def times_ten(ds1, ds2): return ds1 * ds2 @@ -115,7 +115,7 @@ def test_no_trees_returned(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - @map_over_subtree + @map_over_datasets def bad_func(ds1, ds2): return None @@ -125,7 +125,7 @@ def bad_func(ds1, ds2): def test_single_dt_arg(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -136,7 +136,7 @@ def times_ten(ds): def test_single_dt_arg_plus_args_and_kwargs(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def multiply_then_add(ds, times, add=0.0): return (times * ds) + add @@ -148,7 +148,7 @@ def test_multiple_dt_args(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - @map_over_subtree + @map_over_datasets def add(ds1, ds2): return ds1 + 
ds2 @@ -160,7 +160,7 @@ def test_dt_as_kwarg(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - @map_over_subtree + @map_over_datasets def add(ds1, value=0.0): return ds1 + value @@ -171,7 +171,7 @@ def add(ds1, value=0.0): def test_return_multiple_dts(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def minmax(ds): return ds.min(), ds.max() @@ -184,7 +184,7 @@ def minmax(ds): def test_return_wrong_type(self, simple_datatree): dt1 = simple_datatree - @map_over_subtree + @map_over_datasets def bad_func(ds1): return "string" @@ -194,7 +194,7 @@ def bad_func(ds1): def test_return_tuple_of_wrong_types(self, simple_datatree): dt1 = simple_datatree - @map_over_subtree + @map_over_datasets def bad_func(ds1): return xr.Dataset(), "string" @@ -205,7 +205,7 @@ def bad_func(ds1): def test_return_inconsistent_number_of_results(self, simple_datatree): dt1 = simple_datatree - @map_over_subtree + @map_over_datasets def bad_func(ds): # Datasets in simple_datatree have different numbers of dims # TODO need to instead return different numbers of Dataset objects for this test to catch the intended error @@ -217,7 +217,7 @@ def bad_func(ds): def test_wrong_number_of_arguments_for_func(self, simple_datatree): dt = simple_datatree - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -229,7 +229,7 @@ def times_ten(ds): def test_map_single_dataset_against_whole_tree(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def nodewise_merge(node_ds, fixed_ds): return xr.merge([node_ds, fixed_ds]) @@ -250,7 +250,7 @@ def multiply_then_add(ds, times, add=0.0): return times * ds + add expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) - result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) + result_tree = dt.map_over_datasets(multiply_then_add, 10.0, add=2.0) assert_equal(result_tree, expected) def test_discard_ancestry(self, create_test_datatree): @@ -258,7 +258,7 @@ def test_discard_ancestry(self, create_test_datatree): dt = create_test_datatree() subtree = dt["set1"] - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -276,7 +276,7 @@ def check_for_data(ds): assert len(ds.variables) != 0 return ds - dt.map_over_subtree(check_for_data) + dt.map_over_datasets(check_for_data) def test_keep_attrs_on_empty_nodes(self, create_test_datatree): # GH278 @@ -286,7 +286,7 @@ def test_keep_attrs_on_empty_nodes(self, create_test_datatree): def empty_func(ds): return ds - result = dt.map_over_subtree(empty_func) + result = dt.map_over_datasets(empty_func) assert result["set1/set2"].attrs == dt["set1/set2"].attrs @pytest.mark.xfail( @@ -304,17 +304,17 @@ def fail_on_specific_node(ds): with pytest.raises( ValueError, match="Raised whilst mapping function over node /set1" ): - dt.map_over_subtree(fail_on_specific_node) + dt.map_over_datasets(fail_on_specific_node) def test_inherited_coordinates_with_index(self): root = xr.Dataset(coords={"x": [1, 2]}) child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates tree = xr.DataTree.from_dict({"/": root, "/child": child}) - actual = tree.map_over_subtree(lambda ds: ds) # identity + actual = tree.map_over_datasets(lambda ds: ds) # identity assert isinstance(actual, xr.DataTree) assert_identical(tree, actual) - actual_child = actual.children["child"].to_dataset(inherited=False) + actual_child = actual.children["child"].to_dataset(inherit=False) assert_identical(actual_child, 
child) @@ -338,7 +338,7 @@ def test_construct_using_type(self): def weighted_mean(ds): return ds.weighted(ds.area).mean(["x", "y"]) - dt.map_over_subtree(weighted_mean) + dt.map_over_datasets(weighted_mean) def test_alter_inplace_forbidden(self): simpsons = xr.DataTree.from_dict( @@ -359,10 +359,10 @@ def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset: return ds with pytest.raises(AttributeError): - simpsons.map_over_subtree(fast_forward, years=10) + simpsons.map_over_datasets(fast_forward, years=10) @pytest.mark.xfail class TestMapOverSubTreeInplace: - def test_map_over_subtree_inplace(self): + def test_map_over_datasets_inplace(self): raise NotImplementedError diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 9ef4a688302..f62fbb63cb5 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -283,16 +283,16 @@ def test_infix_dims_errors(supplied, all_): pytest.param(..., ..., id="ellipsis"), ], ) -def test_parse_dims(dim, expected) -> None: +def test_parse_dims_as_tuple(dim, expected) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables - actual = utils.parse_dims(dim, all_dims, replace_none=False) + actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=False) assert actual == expected def test_parse_dims_set() -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables dim = {"a", 1} - actual = utils.parse_dims(dim, all_dims) + actual = utils.parse_dims_as_tuple(dim, all_dims) assert set(actual) == dim @@ -301,7 +301,7 @@ def test_parse_dims_set() -> None: ) def test_parse_dims_replace_none(dim: None | EllipsisType) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables - actual = utils.parse_dims(dim, all_dims, replace_none=True) + actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=True) assert actual == all_dims @@ -316,7 +316,7 @@ def test_parse_dims_replace_none(dim: None | EllipsisType) -> None: def test_parse_dims_raises(dim) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables with pytest.raises(ValueError, match="'x'"): - utils.parse_dims(dim, all_dims, check_exists=True) + utils.parse_dims_as_tuple(dim, all_dims, check_exists=True) @pytest.mark.parametrize( diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index d2fc4f6d4e2..089ef558581 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -263,7 +263,7 @@ class DataStructure: create_example: str example_var_name: str numeric_only: bool = False - see_also_modules: tuple[str] = tuple + see_also_modules: tuple[str, ...] = tuple class Method: @@ -287,13 +287,13 @@ def __init__( self.additional_notes = additional_notes if bool_reduce: self.array_method = f"array_{name}" - self.np_example_array = """ - ... np.array([True, True, True, True, True, False], dtype=bool)""" + self.np_example_array = ( + """np.array([True, True, True, True, True, False], dtype=bool)""" + ) else: self.array_method = name - self.np_example_array = """ - ... np.array([1, 2, 3, 0, 2, np.nan])""" + self.np_example_array = """np.array([1, 2, 3, 0, 2, np.nan])""" @dataclass @@ -541,10 +541,27 @@ def generate_code(self, method, has_keep_attrs): ) +DATATREE_OBJECT = DataStructure( + name="DataTree", + create_example=""" + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", {example_array})), + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... 
labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... )""", + example_var_name="dt", + numeric_only=True, + see_also_modules=("Dataset", "DataArray"), +) DATASET_OBJECT = DataStructure( name="Dataset", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -559,7 +576,8 @@ def generate_code(self, method, has_keep_attrs): DATAARRAY_OBJECT = DataStructure( name="DataArray", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -570,6 +588,15 @@ def generate_code(self, method, has_keep_attrs): numeric_only=False, see_also_modules=("Dataset",), ) +DATATREE_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=DATATREE_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=AGGREGATIONS_PREAMBLE, +) DATASET_GENERATOR = GenericAggregationGenerator( cls="", datastructure=DATASET_OBJECT, @@ -634,7 +661,7 @@ def generate_code(self, method, has_keep_attrs): create_example=""" >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x",{example_array}, + ... "x", {example_array} ... )""", example_var_name="na", numeric_only=False, @@ -670,6 +697,7 @@ def write_methods(filepath, generators, preamble): write_methods( filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py", generators=[ + DATATREE_GENERATOR, DATASET_GENERATOR, DATAARRAY_GENERATOR, DATASET_GROUPBY_GENERATOR,