diff --git a/doc/api.rst b/doc/api.rst index 63fb59bc5e0..add1eed8944 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -761,16 +761,17 @@ Compare one ``DataTree`` object to another. DataTree.equals DataTree.identical -.. Indexing -.. -------- +Indexing +-------- -.. Index into all nodes in the subtree simultaneously. +Index into all nodes in the subtree simultaneously. -.. .. autosummary:: -.. :toctree: generated/ +.. autosummary:: + :toctree: generated/ + + DataTree.isel + DataTree.sel -.. DataTree.isel -.. DataTree.sel .. DataTree.drop_sel .. DataTree.drop_isel .. DataTree.head diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index f8da26c28b5..04683a511f3 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -32,11 +32,13 @@ from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS from xarray.core.treenode import NamedNode, NodePath +from xarray.core.types import Self from xarray.core.utils import ( Default, FilteredMapping, Frozen, _default, + drop_dims_from_indexers, either_dict_or_kwargs, maybe_wrap_array, ) @@ -54,7 +56,12 @@ from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.merge import CoercibleMapping, CoercibleValue - from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes + from xarray.core.types import ( + ErrorOptions, + ErrorOptionsWithWarn, + NetcdfWriteModes, + ZarrWriteModes, + ) # """ # DEVELOPERS' NOTE @@ -1087,7 +1094,7 @@ def from_dict( d: Mapping[str, Dataset | DataTree | None], /, name: str | None = None, - ) -> DataTree: + ) -> Self: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1607,3 +1614,180 @@ def to_zarr( compute=compute, **kwargs, ) + + def _selective_indexing( + self, + func: Callable[[Dataset, Mapping[Any, Any]], Dataset], + indexers: Mapping[Any, Any], + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> Self: + """Apply an indexing operation over the subtree, handling missing + dimensions and inherited coordinates gracefully by only applying + indexing at each node selectively. + """ + all_dims = set() + for node in self.subtree: + all_dims.update(node._node_dims) + indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims) + + result = {} + for node in self.subtree: + node_indexers = {k: v for k, v in indexers.items() if k in node.dims} + node_result = func(node.dataset, node_indexers) + # Indexing datasets corresponding to each node results in redundant + # coordinates when indexes from a parent node are inherited. + # Ideally, we would avoid creating such coordinates in the first + # place, but that would require implementing indexing operations at + # the Variable instead of the Dataset level. + for k in node_indexers: + if k not in node._node_coord_variables and k in node_result.coords: + # We remove all inherited coordinates. Coordinates + # corresponding to an index would be de-duplicated by + # _deduplicate_inherited_coordinates(), but indexing (e.g., + # with a scalar) can also create scalar coordinates, which + # need to be explicitly removed. + del node_result.coords[k] + result[node.path] = node_result + return type(self).from_dict(result, name=self.name) + + def isel( + self, + indexers: Mapping[Any, Any] | None = None, + drop: bool = False, + missing_dims: ErrorOptionsWithWarn = "raise", + **indexers_kwargs: Any, + ) -> Self: + """Returns a new data tree with each array indexed along the specified + dimension(s). + + This method selects values from each array using its `__getitem__` + method, except this method does not require knowing the order of + each array's dimensions. + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by integers, slice objects or arrays. + indexer can be a integer, slice, array-like or DataArray. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + drop : bool, default: False + If ``drop=True``, drop coordinates variables indexed by integers + instead of making them scalar. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : DataTree + A new DataTree with the same contents as this data tree, except each + array and dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this dataset, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. + + See Also + -------- + DataTree.sel + Dataset.isel + """ + + def apply_indexers(dataset, node_indexers): + return dataset.isel(node_indexers, drop=drop) + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") + return self._selective_indexing( + apply_indexers, indexers, missing_dims=missing_dims + ) + + def sel( + self, + indexers: Mapping[Any, Any] | None = None, + method: str | None = None, + tolerance: int | float | Iterable[int | float] | None = None, + drop: bool = False, + **indexers_kwargs: Any, + ) -> Self: + """Returns a new data tree with each array indexed by tick labels + along the specified dimension(s). + + In contrast to `DataTree.isel`, indexers for this method should use + labels instead of integers. + + Under the hood, this method is powered by using pandas's powerful Index + objects. This makes label based indexing essentially just as fast as + using integer indexing. + + It also means this method uses pandas's (well documented) logic for + indexing. This means you can use string shortcuts for datetime indexes + (e.g., '2000-01' to select all values in January 2000). It also means + that slices are treated as inclusive of both the start and stop values, + unlike normal Python indexing. + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by scalars, slices or arrays of tick labels. For dimensions with + multi-index, the indexer may also be a dict-like object with keys + matching index level names. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method to use for inexact matches: + + * None (default): only exact matches + * pad / ffill: propagate last valid index value forward + * backfill / bfill: propagate next valid index value backward + * nearest: use nearest valid index value + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + drop : bool, optional + If ``drop=True``, drop coordinates variables in `indexers` instead + of making them scalar. + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : DataTree + A new DataTree with the same contents as this data tree, except each + variable and dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this dataset, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. + + See Also + -------- + DataTree.isel + Dataset.sel + """ + + def apply_indexers(dataset, node_indexers): + # TODO: reimplement in terms of map_index_queries(), to avoid + # redundant look-ups of integer positions from labels (via indexes) + # on child nodes. + return dataset.sel( + node_indexers, method=method, tolerance=tolerance, drop=drop + ) + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") + return self._selective_indexing(apply_indexers, indexers) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 3a3afb0647a..1de53a35311 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -971,7 +971,6 @@ def test_ipython_key_completions(self, create_test_datatree): var_keys = list(dt.variables.keys()) assert all(var_key in key_completions for var_key in var_keys) - @pytest.mark.xfail(reason="sel not implemented yet") def test_operation_with_attrs_but_no_data(self): # tests bug from xarray-datatree GH262 xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))}) @@ -1561,26 +1560,95 @@ def test_filter(self): assert_identical(elders, expected) -class TestDSMethodInheritance: - @pytest.mark.xfail(reason="isel not implemented yet") - def test_dataset_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree.from_dict( +class TestIndexing: + + def test_isel_siblings(self): + tree = DataTree.from_dict( { - "/": ds, - "/results": ds, + "/first": xr.Dataset({"a": ("x", [1, 2])}), + "/second": xr.Dataset({"b": ("x", [1, 2, 3])}), } ) expected = DataTree.from_dict( { - "/": ds.isel(x=1), - "/results": ds.isel(x=1), + "/first": xr.Dataset({"a": 2}), + "/second": xr.Dataset({"b": 3}), } ) + actual = tree.isel(x=-1) + assert_equal(actual, expected) - result = dt.isel(x=1) - assert_equal(result, expected) + expected = DataTree.from_dict( + { + "/first": xr.Dataset({"a": ("x", [1])}), + "/second": xr.Dataset({"b": ("x", [1])}), + } + ) + actual = tree.isel(x=slice(1)) + assert_equal(actual, expected) + + actual = tree.isel(x=[0]) + assert_equal(actual, expected) + + actual = tree.isel(x=slice(None)) + assert_equal(actual, tree) + + def test_isel_inherited(self): + tree = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2]}), + "/child": xr.Dataset({"foo": ("x", [3, 4])}), + } + ) + + expected = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": 2}), + "/child": xr.Dataset({"foo": 4}), + } + ) + actual = tree.isel(x=-1) + assert_equal(actual, expected) + + expected = DataTree.from_dict( + { + "/child": xr.Dataset({"foo": 4}), + } + ) + actual = tree.isel(x=-1, drop=True) + assert_equal(actual, expected) + + expected = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1]}), + "/child": xr.Dataset({"foo": ("x", [3])}), + } + ) + actual = tree.isel(x=[0]) + assert_equal(actual, expected) + + actual = tree.isel(x=slice(None)) + assert_equal(actual, tree) + + def test_sel(self): + tree = DataTree.from_dict( + { + "/first": xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]}), + "/second": xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]}), + } + ) + expected = DataTree.from_dict( + { + "/first": xr.Dataset({"a": 2}, coords={"x": 2}), + "/second": xr.Dataset({"b": 4}, coords={"x": 2}), + } + ) + actual = tree.sel(x=2) + assert_equal(actual, expected) + + +class TestDSMethodInheritance: @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_reduce_method(self):