From 100712d05af0e512154122b25b186887181a3762 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 6 Oct 2024 12:51:26 +0900 Subject: [PATCH 1/6] Implement DataTree.isel and DataTree.sel --- xarray/core/datatree.py | 67 +++++++++++++++++++++++++++- xarray/tests/test_datatree.py | 83 ++++++++++++++++++++++++++++++----- 2 files changed, 137 insertions(+), 13 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index f8da26c28b5..cdf46059b6c 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -32,11 +32,13 @@ from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS from xarray.core.treenode import NamedNode, NodePath +from xarray.core.types import Self from xarray.core.utils import ( Default, FilteredMapping, Frozen, _default, + drop_dims_from_indexers, either_dict_or_kwargs, maybe_wrap_array, ) @@ -54,7 +56,12 @@ from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.merge import CoercibleMapping, CoercibleValue - from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes + from xarray.core.types import ( + ErrorOptions, + ErrorOptionsWithWarn, + NetcdfWriteModes, + ZarrWriteModes, + ) # """ # DEVELOPERS' NOTE @@ -1607,3 +1614,61 @@ def to_zarr( compute=compute, **kwargs, ) + + def _selective_indexing( + self, + func: Callable[[Dataset, Mapping[Any, Any]]], + indexers: Mapping[Any, Any], + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> Self: + all_dims = set() + for node in self.subtree: + all_dims.update(node._node_dims) + indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims) + + result = {} + for node in self.subtree: + node_indexers = {k: v for k, v in indexers.items() if k in node.dims} + node_result = func(node.dataset, node_indexers) + for k in node_indexers: + if k not in node.coords and k in node_result.coords: + del node_result.coords[k] + result[node.path] = node_result + return 
type(self).from_dict(result, name=self.name) # type: ignore + + def isel( + self, + indexers: Mapping[Any, Any] | None = None, + drop: bool = False, + missing_dims: ErrorOptionsWithWarn = "raise", + **indexers_kwargs: Any, + ) -> Self: + """Positional indexing.""" + + def apply_indexers(dataset, node_indexers): + return dataset.isel(node_indexers, drop=drop) + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") + return self._selective_indexing( + apply_indexers, indexers, missing_dims=missing_dims + ) + + def sel( + self, + indexers: Mapping[Any, Any] | None = None, + method: str | None = None, + tolerance: int | float | Iterable[int | float] | None = None, + drop: bool = False, + **indexers_kwargs: Any, + ) -> Self: + """Label-based indexing.""" + + def apply_indexers(dataset, node_indexers): + # TODO: reimplement in terms of map_index_queries(), to avoid + # redundant index look-ups on child nodes + return dataset.sel( + node_indexers, method=method, tolerance=tolerance, drop=drop + ) + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") + return self._selective_indexing(apply_indexers, indexers) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 3a3afb0647a..86fbad3526d 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -971,7 +971,6 @@ def test_ipython_key_completions(self, create_test_datatree): var_keys = list(dt.variables.keys()) assert all(var_key in key_completions for var_key in var_keys) - @pytest.mark.xfail(reason="sel not implemented yet") def test_operation_with_attrs_but_no_data(self): # tests bug from xarray-datatree GH262 xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))}) @@ -1561,26 +1560,86 @@ def test_filter(self): assert_identical(elders, expected) -class TestDSMethodInheritance: - @pytest.mark.xfail(reason="isel not implemented yet") - def test_dataset_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree.from_dict( 
+class TestIndexing: + + def test_isel_siblings(self): + tree = DataTree.from_dict( { - "/": ds, - "/results": ds, + "/first": xr.Dataset({"a": ("x", [1, 2])}), + "/second": xr.Dataset({"b": ("x", [1, 2, 3])}), } ) expected = DataTree.from_dict( { - "/": ds.isel(x=1), - "/results": ds.isel(x=1), + "/first": xr.Dataset({"a": 2}), + "/second": xr.Dataset({"b": 3}), } ) + actual = tree.isel(x=-1) + assert_equal(actual, expected) - result = dt.isel(x=1) - assert_equal(result, expected) + expected = DataTree.from_dict( + { + "/first": xr.Dataset({"a": ("x", [1])}), + "/second": xr.Dataset({"b": ("x", [1])}), + } + ) + actual = tree.isel(x=slice(1)) + assert_equal(actual, expected) + + actual = tree.isel(x=[0]) + assert_equal(actual, expected) + + actual = tree.isel(x=slice(None)) + assert_equal(actual, tree) + + def test_isel_inherited(self): + tree = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2]}), + "/child": xr.Dataset({"foo": ("x", [3, 4])}), + } + ) + + expected = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": 2}), + "/child": xr.Dataset({"foo": 4}), + } + ) + actual = tree.isel(x=-1) + assert_equal(actual, expected) + + expected = DataTree.from_dict( + { + "/child": xr.Dataset({"foo": 4}), + } + ) + actual = tree.isel(x=-1, drop=True) + assert_equal(actual, expected) + + actual = tree.isel(x=slice(None)) + assert_equal(actual, tree) + + def test_sel(self): + tree = DataTree.from_dict( + { + "/first": xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]}), + "/second": xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]}), + } + ) + expected = DataTree.from_dict( + { + "/first": xr.Dataset({"a": 2}, coords={"x": 2}), + "/second": xr.Dataset({"b": 4}, coords={"x": 2}), + } + ) + actual = tree.sel(x=2) + assert_equal(actual, expected) + + +class TestDSMethodInheritance: @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_reduce_method(self): From e312c105e31fa3d42e7eb2baf1eb81f99f02ca9f Mon Sep 17 00:00:00 2001 
From: Stephan Hoyer Date: Mon, 7 Oct 2024 20:47:03 +0900 Subject: [PATCH 2/6] add api docs --- doc/api.rst | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 63fb59bc5e0..add1eed8944 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -761,16 +761,17 @@ Compare one ``DataTree`` object to another. DataTree.equals DataTree.identical -.. Indexing -.. -------- +Indexing +-------- -.. Index into all nodes in the subtree simultaneously. +Index into all nodes in the subtree simultaneously. -.. .. autosummary:: -.. :toctree: generated/ +.. autosummary:: + :toctree: generated/ + + DataTree.isel + DataTree.sel -.. DataTree.isel -.. DataTree.sel .. DataTree.drop_sel .. DataTree.drop_isel .. DataTree.head From 9a4c895b6a5271c55be176366ac88ba0a8524e6f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 7 Oct 2024 21:10:24 +0900 Subject: [PATCH 3/6] fix CI failures --- xarray/core/datatree.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index cdf46059b6c..b909c8e5ade 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1094,7 +1094,7 @@ def from_dict( d: Mapping[str, Dataset | DataTree | None], /, name: str | None = None, - ) -> DataTree: + ) -> Self: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. 
@@ -1617,7 +1617,7 @@ def to_zarr( def _selective_indexing( self, - func: Callable[[Dataset, Mapping[Any, Any]]], + func: Callable[[Dataset, Mapping[Any, Any]], Dataset], indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise", ) -> Self: @@ -1631,10 +1631,10 @@ def _selective_indexing( node_indexers = {k: v for k, v in indexers.items() if k in node.dims} node_result = func(node.dataset, node_indexers) for k in node_indexers: - if k not in node.coords and k in node_result.coords: + if k not in node._node_coord_variables and k in node_result.coords: del node_result.coords[k] result[node.path] = node_result - return type(self).from_dict(result, name=self.name) # type: ignore + return type(self).from_dict(result, name=self.name) def isel( self, From c7fe8feeb54a8b47b6cf98c1b5ad9bf165d007e7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 8 Oct 2024 11:38:34 +0900 Subject: [PATCH 4/6] add docstrings for DataTree.isel and DataTree.sel --- xarray/core/datatree.py | 108 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 2 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index b909c8e5ade..a9767fa0daa 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1643,7 +1643,52 @@ def isel( missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, ) -> Self: - """Positional indexing.""" + """Returns a new data tree with each array indexed along the specified + dimension(s). + + This method selects values from each array using its `__getitem__` + method, except this method does not require knowing the order of + each array's dimensions. + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by integers, slice objects or arrays. + indexer can be a integer, slice, array-like or DataArray. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. 
+ One of indexers or indexers_kwargs must be provided. + drop : bool, default: False + If ``drop=True``, drop coordinate variables indexed by integers + instead of making them scalar. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + DataTree: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : DataTree + A new DataTree with the same contents as this data tree, except each + array and dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this data tree, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. + + See Also + -------- + DataTree.sel + Dataset.isel + """ def apply_indexers(dataset, node_indexers): return dataset.isel(node_indexers, drop=drop) @@ -1661,7 +1706,66 @@ def sel( drop: bool = False, **indexers_kwargs: Any, ) -> Self: - """Label-based indexing.""" + """Returns a new data tree with each array indexed by tick labels + along the specified dimension(s). + + In contrast to `DataTree.isel`, indexers for this method should use + labels instead of integers. + + Under the hood, this method is powered by using pandas's powerful Index + objects. This makes label-based indexing essentially just as fast as + using integer indexing. + + It also means this method uses pandas's (well-documented) logic for + indexing. This means you can use string shortcuts for datetime indexes + (e.g., '2000-01' to select all values in January 2000). 
It also means + that slices are treated as inclusive of both the start and stop values, + unlike normal Python indexing. + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by scalars, slices or arrays of tick labels. For dimensions with + multi-index, the indexer may also be a dict-like object with keys + matching index level names. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method to use for inexact matches: + + * None (default): only exact matches + * pad / ffill: propagate last valid index value forward + * backfill / bfill: propagate next valid index value backward + * nearest: use nearest valid index value + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + drop : bool, optional + If ``drop=True``, drop coordinate variables in `indexers` instead + of making them scalar. + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : DataTree + A new DataTree with the same contents as this data tree, except each + variable and dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this data tree, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. 
+ + See Also + -------- + DataTree.isel + Dataset.sel + """ def apply_indexers(dataset, node_indexers): # TODO: reimplement in terms of map_index_queries(), to avoid From 93b781323a018c88964a81ce51135d5b6ff8517e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 8 Oct 2024 11:47:45 +0900 Subject: [PATCH 5/6] Add comments --- xarray/core/datatree.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index a9767fa0daa..04683a511f3 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1621,6 +1621,10 @@ def _selective_indexing( indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise", ) -> Self: + """Apply an indexing operation over the subtree, handling missing + dimensions and inherited coordinates gracefully by only applying + indexing at each node selectively. + """ all_dims = set() for node in self.subtree: all_dims.update(node._node_dims) @@ -1630,8 +1634,18 @@ def _selective_indexing( for node in self.subtree: node_indexers = {k: v for k, v in indexers.items() if k in node.dims} node_result = func(node.dataset, node_indexers) + # Indexing datasets corresponding to each node results in redundant + # coordinates when indexes from a parent node are inherited. + # Ideally, we would avoid creating such coordinates in the first + # place, but that would require implementing indexing operations at + # the Variable instead of the Dataset level. for k in node_indexers: if k not in node._node_coord_variables and k in node_result.coords: + # We remove all inherited coordinates. Coordinates + # corresponding to an index would be de-duplicated by + # _deduplicate_inherited_coordinates(), but indexing (e.g., + # with a scalar) can also create scalar coordinates, which + # need to be explicitly removed. 
del node_result.coords[k] result[node.path] = node_result return type(self).from_dict(result, name=self.name) @@ -1769,7 +1783,8 @@ def sel( def apply_indexers(dataset, node_indexers): # TODO: reimplement in terms of map_index_queries(), to avoid - # redundant index look-ups on child nodes + # redundant look-ups of integer positions from labels (via indexes) + # on child nodes. return dataset.sel( node_indexers, method=method, tolerance=tolerance, drop=drop ) From f3eecf44e220f5bff8b67212d4693a039efa91d5 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 8 Oct 2024 11:50:28 +0900 Subject: [PATCH 6/6] add another indexing test --- xarray/tests/test_datatree.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 86fbad3526d..1de53a35311 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1619,6 +1619,15 @@ def test_isel_inherited(self): actual = tree.isel(x=-1, drop=True) assert_equal(actual, expected) + expected = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1]}), + "/child": xr.Dataset({"foo": ("x", [3])}), + } + ) + actual = tree.isel(x=[0]) + assert_equal(actual, expected) + actual = tree.isel(x=slice(None)) assert_equal(actual, tree)