From 7486f4ee794a7b8ea7185d912a226a9d97c3ec19 Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Wed, 16 Oct 2024 02:41:25 +0900
Subject: [PATCH] Bug fixes for DataTree indexing and aggregation (#9626)

* Bug fixes for DataTree indexing and aggregation

My implementation of indexing and aggregation was incorrect on child
nodes, re-creating the child nodes from the root.

There was also another bug when indexing inherited coordinates that
meant formerly inherited coordinates were incorrectly dropped from
results.

* disable broken test
---
 xarray/core/datatree.py       | 23 +++++++++++++----------
 xarray/tests/test_datatree.py | 32 ++++++++++++++++++++++++++++++--
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py
index 452a1af2d3b..702799278a3 100644
--- a/xarray/core/datatree.py
+++ b/xarray/core/datatree.py
@@ -1681,7 +1681,8 @@ def reduce(
                 numeric_only=numeric_only,
                 **kwargs,
             )
-            result[node.path] = node_result
+            path = "/" if node is self else node.relative_to(self)
+            result[path] = node_result
         return type(self).from_dict(result, name=self.name)
 
     def _selective_indexing(
@@ -1706,15 +1707,17 @@ def _selective_indexing(
             # Ideally, we would avoid creating such coordinates in the first
             # place, but that would require implementing indexing operations at
             # the Variable instead of the Dataset level.
-            for k in node_indexers:
-                if k not in node._node_coord_variables and k in node_result.coords:
-                    # We remove all inherited coordinates. Coordinates
-                    # corresponding to an index would be de-duplicated by
-                    # _deduplicate_inherited_coordinates(), but indexing (e.g.,
-                    # with a scalar) can also create scalar coordinates, which
-                    # need to be explicitly removed.
-                    del node_result.coords[k]
-            result[node.path] = node_result
+            if node is not self:
+                for k in node_indexers:
+                    if k not in node._node_coord_variables and k in node_result.coords:
+                        # We remove all inherited coordinates. Coordinates
+                        # corresponding to an index would be de-duplicated by
+                        # _deduplicate_inherited_coordinates(), but indexing (e.g.,
+                        # with a scalar) can also create scalar coordinates, which
+                        # need to be explicitly removed.
+                        del node_result.coords[k]
+            path = "/" if node is self else node.relative_to(self)
+            result[path] = node_result
         return type(self).from_dict(result, name=self.name)
 
     def isel(
diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py
index c43f53a22c1..97fe8212c86 100644
--- a/xarray/tests/test_datatree.py
+++ b/xarray/tests/test_datatree.py
@@ -1651,7 +1651,18 @@ def test_isel_inherited(self):
         assert_equal(actual, expected)
 
         actual = tree.isel(x=slice(None))
-        assert_equal(actual, tree)
+
+        # TODO: re-enable after the fix to copy() from #9628 is submitted
+        # actual = tree.children["child"].isel(x=slice(None))
+        # expected = tree.children["child"].copy()
+        # assert_identical(actual, expected)
+
+        actual = tree.children["child"].isel(x=0)
+        expected = DataTree(
+            dataset=xr.Dataset({"foo": 3}, coords={"x": 1}),
+            name="child",
+        )
+        assert_identical(actual, expected)
 
     def test_sel(self):
         tree = DataTree.from_dict(
@@ -1667,7 +1678,14 @@ def test_sel(self):
             }
         )
         actual = tree.sel(x=2)
-        assert_equal(actual, expected)
+        assert_identical(actual, expected)
+
+        actual = tree.children["first"].sel(x=2)
+        expected = DataTree(
+            dataset=xr.Dataset({"a": 2}, coords={"x": 2}),
+            name="first",
+        )
+        assert_identical(actual, expected)
 
 
 class TestAggregations:
@@ -1739,6 +1757,16 @@ def test_dim_argument(self):
         ):
             dt.mean("invalid")
 
+    def test_subtree(self):
+        tree = DataTree.from_dict(
+            {
+                "/child": Dataset({"a": ("x", [1, 2])}),
+            }
+        )
+        expected = DataTree(dataset=Dataset({"a": 1.5}), name="child")
+        actual = tree.children["child"].mean()
+        assert_identical(expected, actual)
+
 
 class TestOps:
     def test_unary_op(self):