Skip to content

Commit

Permalink
Support alternative names for the root node in DataTree.from_dict (#9638
Browse files Browse the repository at this point in the history
)
  • Loading branch information
shoyer authored Oct 17, 2024
1 parent de3fce8 commit 7046255
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
33 changes: 22 additions & 11 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,10 +1104,12 @@ def from_dict(
d : dict-like
A mapping from path names to xarray.Dataset or DataTree objects.
Path names are to be given as unix-like path. If path names containing more than one
part are given, new tree nodes will be constructed as necessary.
Path names are to be given as unix-like path. If path names
containing more than one part are given, new tree nodes will be
constructed as necessary.
To assign data to the root node of the tree use "/" as the path.
To assign data to the root node of the tree use "", ".", "/" or "./"
as the path.
name : Hashable | None, optional
Name for the root node of the tree. Default is None.
Expand All @@ -1119,17 +1121,27 @@ def from_dict(
-----
If your dictionary is nested you will need to flatten it before using this method.
"""

# First create the root node
# Find any values corresponding to the root
d_cast = dict(d)
root_data = d_cast.pop("/", None)
root_data = None
for key in ("", ".", "/", "./"):
if key in d_cast:
if root_data is not None:
raise ValueError(
"multiple entries found corresponding to the root node"
)
root_data = d_cast.pop(key)

# Create the root node
if isinstance(root_data, DataTree):
obj = root_data.copy()
obj.name = name
elif root_data is None or isinstance(root_data, Dataset):
obj = cls(name=name, dataset=root_data, children=None)
else:
raise TypeError(
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
f'root node data (at "", ".", "/" or "./") must be a Dataset '
f"or DataTree, got {type(root_data)}"
)

def depth(item) -> int:
Expand All @@ -1141,11 +1153,10 @@ def depth(item) -> int:
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, dataset=data)
new_node = cls(dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
Expand Down Expand Up @@ -1683,7 +1694,7 @@ def reduce(
numeric_only=numeric_only,
**kwargs,
)
path = "/" if node is self else node.relative_to(self)
path = node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

Expand Down Expand Up @@ -1718,7 +1729,7 @@ def _selective_indexing(
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
path = "/" if node is self else node.relative_to(self)
path = node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

Expand Down
42 changes: 42 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,48 @@ def test_array_values(self) -> None:
with pytest.raises(TypeError):
DataTree.from_dict(data) # type: ignore[arg-type]

def test_relative_paths(self) -> None:
tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None})
paths = [node.path for node in tree.subtree]
assert paths == [
"/",
"/foo",
"/bar",
"/x",
"/x/y",
]

def test_root_keys(self):
ds = Dataset({"x": 1})
expected = DataTree(dataset=ds)

actual = DataTree.from_dict({"": ds})
assert_identical(actual, expected)

actual = DataTree.from_dict({".": ds})
assert_identical(actual, expected)

actual = DataTree.from_dict({"/": ds})
assert_identical(actual, expected)

actual = DataTree.from_dict({"./": ds})
assert_identical(actual, expected)

with pytest.raises(
ValueError, match="multiple entries found corresponding to the root node"
):
DataTree.from_dict({"": ds, "/": ds})

def test_name(self):
tree = DataTree.from_dict({"/": None}, name="foo")
assert tree.name == "foo"

tree = DataTree.from_dict({"/": DataTree()}, name="foo")
assert tree.name == "foo"

tree = DataTree.from_dict({"/": DataTree(name="bar")}, name="foo")
assert tree.name == "foo"


class TestDatasetView:
def test_view_contents(self) -> None:
Expand Down

0 comments on commit 7046255

Please sign in to comment.