Skip to content

Commit

Permalink
Merge branch 'main' into datatree-memo
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Oct 15, 2024
2 parents 73fcba6 + 7486f4e commit e65021d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
23 changes: 13 additions & 10 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,8 @@ def reduce(
numeric_only=numeric_only,
**kwargs,
)
result[node.path] = node_result
path = "/" if node is self else node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

def _selective_indexing(
Expand All @@ -1708,15 +1709,17 @@ def _selective_indexing(
# Ideally, we would avoid creating such coordinates in the first
# place, but that would require implementing indexing operations at
# the Variable instead of the Dataset level.
for k in node_indexers:
if k not in node._node_coord_variables and k in node_result.coords:
# We remove all inherited coordinates. Coordinates
# corresponding to an index would be de-duplicated by
# _deduplicate_inherited_coordinates(), but indexing (e.g.,
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
result[node.path] = node_result
if node is not self:
for k in node_indexers:
if k not in node._node_coord_variables and k in node_result.coords:
# We remove all inherited coordinates. Coordinates
# corresponding to an index would be de-duplicated by
# _deduplicate_inherited_coordinates(), but indexing (e.g.,
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
path = "/" if node is self else node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

def isel(
Expand Down
32 changes: 30 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,18 @@ def test_isel_inherited(self):
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

# TODO: re-enable after the fix to copy() from #9628 is submitted
# actual = tree.children["child"].isel(x=slice(None))
# expected = tree.children["child"].copy()
# assert_identical(actual, expected)

actual = tree.children["child"].isel(x=0)
expected = DataTree(
dataset=xr.Dataset({"foo": 3}, coords={"x": 1}),
name="child",
)
assert_identical(actual, expected)

def test_sel(self):
tree = DataTree.from_dict(
Expand All @@ -1667,7 +1678,14 @@ def test_sel(self):
}
)
actual = tree.sel(x=2)
assert_equal(actual, expected)
assert_identical(actual, expected)

actual = tree.children["first"].sel(x=2)
expected = DataTree(
dataset=xr.Dataset({"a": 2}, coords={"x": 2}),
name="first",
)
assert_identical(actual, expected)


class TestAggregations:
Expand Down Expand Up @@ -1739,6 +1757,16 @@ def test_dim_argument(self):
):
dt.mean("invalid")

def test_subtree(self):
tree = DataTree.from_dict(
{
"/child": Dataset({"a": ("x", [1, 2])}),
}
)
expected = DataTree(dataset=Dataset({"a": 1.5}), name="child")
actual = tree.children["child"].mean()
assert_identical(expected, actual)


class TestOps:
def test_unary_op(self):
Expand Down

0 comments on commit e65021d

Please sign in to comment.