Skip to content

Commit

Permalink
Add zip_subtrees for paired iteration over DataTrees
Browse files Browse the repository at this point in the history
This should be used for implementing DataTree arithmetic inside
map_over_datasets, so the result does not depend on the order in which
child nodes are defined.

I have also added a minimal implementation of breadth-first search with
an explicit queue, replacing the current recursion-based solution in
xarray.core.iterators (which has been removed). The new implementation
is also slightly faster in my microbenchmark:

    In [1]: import xarray as xr

    In [2]: tree = xr.DataTree.from_dict({f"/x{i}": None for i in range(100)})

    In [3]: %timeit _ = list(tree.subtree)
    # on main
    87.2 μs ± 394 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

    # with this branch
    55.1 μs ± 294 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
  • Loading branch information
shoyer committed Oct 15, 2024
1 parent c057d13 commit bde843c
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 151 deletions.
5 changes: 4 additions & 1 deletion xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def check_isomorphic(
Also optionally raised if their structure is isomorphic, but the names of any two
respective nodes are not equal.
"""
# TODO: remove require_names_equal and check_from_root. Instead, check that
# all child nodes match, in any order, which will suffice once
# map_over_datasets switches to use zip_subtrees.

if not isinstance(a, TreeNode):
raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}")
Expand All @@ -68,7 +71,7 @@ def check_isomorphic(

diff = diff_treestructure(a, b, require_names_equal=require_names_equal)

if diff:
if diff is not None:
raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)


Expand Down
29 changes: 20 additions & 9 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
from xarray.core.iterators import LevelOrderIter
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.utils import is_duck_array
from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy
Expand Down Expand Up @@ -981,16 +980,28 @@ def diff_array_repr(a, b, compat):
return "\n".join(summary)


def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
def diff_treestructure(
a: DataTree, b: DataTree, require_names_equal: bool
) -> str | None:
"""
Return a summary of why two trees are not isomorphic.
If they are isomorphic return an empty string.
If they are isomorphic return None.
"""
# .subtree walks nodes in breadth-first order, in order to produce as
# shallow a diff as possible

# TODO: switch zip(a.subtree, b.subtree) to zip_subtrees(a, b), and only
# check that child node names match, e.g.,
# for node_a, node_b in zip_subtrees(a, b):
# if node_a.children.keys() != node_b.children.keys():
# diff = dedent(
# f"""\
# Node {node_a.path!r} in the left object has children {list(node_a.children.keys())}
# Node {node_b.path!r} in the right object has children {list(node_b.children.keys())}"""
# )
# return diff

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
# Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree
# (which it is so long as children are stored in a tuple or list rather than in a set).
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b), strict=True):
for node_a, node_b in zip(a.subtree, b.subtree, strict=True):
path_a, path_b = node_a.path, node_b.path

if require_names_equal and node_a.name != node_b.name:
Expand All @@ -1009,7 +1020,7 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
)
return diff

return ""
return None


def diff_dataset_repr(a, b, compat):
Expand Down Expand Up @@ -1063,7 +1074,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):

# If the trees structures are different there is no point comparing each node
# TODO we could show any differences in nodes up to the first place that structure differs?
if treestructure_diff or compat == "isomorphic":
if treestructure_diff is not None or compat == "isomorphic":
summary.append("\n" + treestructure_diff)
else:
nodewise_diff = diff_nodewise_summary(a, b, compat)
Expand Down
124 changes: 0 additions & 124 deletions xarray/core/iterators.py

This file was deleted.

50 changes: 46 additions & 4 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import collections
import sys
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
Expand Down Expand Up @@ -400,15 +401,18 @@ def subtree(self: Tree) -> Iterator[Tree]:
"""
An iterator over all nodes in this tree, including both self and all descendants.
Iterates depth-first.
Iterates breadth-first.
See Also
--------
DataTree.descendants
"""
from xarray.core.iterators import LevelOrderIter

return LevelOrderIter(self)
# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
queue = collections.deque([self])
while queue:
node = queue.popleft()
yield node
queue.extend(node.children.values())

@property
def descendants(self: Tree) -> tuple[Tree, ...]:
Expand Down Expand Up @@ -773,3 +777,41 @@ def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath:
generation_gap = list(parents_paths).index(ancestor.path)
path_upwards = "../" * generation_gap if generation_gap > 0 else "."
return NodePath(path_upwards)


def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]:
    """Iterate over aligned subtrees in breadth-first order.

    Nodes are matched by child name rather than by position, so the pairing
    does not depend on the order in which each tree's children were defined.
    Children are visited in the order in which they appear on the first tree.

    Parameters
    ----------
    *trees : Tree
        Trees to iterate over. At least one tree must be given.

    Yields
    ------
    Tuples of matching subtrees.
        Each tuple holds one node per input tree, in the order the trees
        were passed.

    Raises
    ------
    TypeError
        If no trees are passed.
    ValueError
        If the child names of a group of matching nodes differ between trees.

    Notes
    -----
    This is a generator, so both errors are raised during iteration rather
    than when ``zip_subtrees`` is called.
    """
    # Without this guard, an empty call would yield an empty tuple and then
    # fail with an unhelpful IndexError on ``active_nodes[0]`` below.
    if not trees:
        raise TypeError("zip_subtrees() requires at least one tree")

    # https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
    queue = collections.deque([trees])

    while queue:
        active_nodes = queue.popleft()

        # yield before raising an error, in case the caller chooses to exit
        # iteration early
        yield active_nodes

        first_node = active_nodes[0]
        if any(
            sibling.children.keys() != first_node.children.keys()
            for sibling in active_nodes[1:]
        ):
            child_summary = " vs ".join(
                str(list(node.children)) for node in active_nodes
            )
            raise ValueError(
                f"children at {first_node.path!r} do not match: {child_summary}"
            )

        # Look children up by name on every tree so that differing child
        # order between trees does not affect the pairing.
        for name in first_node.children:
            queue.append(tuple(node.children[name] for node in active_nodes))
55 changes: 42 additions & 13 deletions xarray/tests/test_treenode.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from collections.abc import Iterator
from typing import cast
import re

import pytest

from xarray.core.iterators import LevelOrderIter
from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode
from xarray.core.treenode import (
InvalidTreeError,
NamedNode,
NodePath,
TreeNode,
zip_subtrees,
)


class TestFamilyTree:
Expand Down Expand Up @@ -299,15 +303,12 @@ def create_test_tree() -> tuple[NamedNode, NamedNode]:
return a, f


class TestIterators:
class TestZipSubtrees:

def test_levelorderiter(self):
def test_one_tree(self):
root, _ = create_test_tree()
result: list[str | None] = [
node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root))
]
expected = [
"a", # root Node is unnamed
"a",
"b",
"c",
"d",
Expand All @@ -317,8 +318,37 @@ def test_levelorderiter(self):
"g",
"i",
]
result = [node[0].name for node in zip_subtrees(root)]
assert result == expected

def test_different_order(self):
    """Check that zip_subtrees pairs children by name, not by position."""
    left: NamedNode = NamedNode(
        name="a", children={"b": NamedNode(), "c": NamedNode()}
    )
    right: NamedNode = NamedNode(
        name="a", children={"c": NamedNode(), "b": NamedNode()}
    )

    # Iterated independently, each tree reports its own child order.
    left_names = [node.name for node in left.subtree]
    assert left_names == ["a", "b", "c"]
    right_names = [node.name for node in right.subtree]
    assert right_names == ["a", "c", "b"]

    # Zipped together, nodes are paired by name in the first tree's order.
    paired_names = [(x.name, y.name) for x, y in zip_subtrees(left, right)]
    expected = [
        ("a", "a"),
        ("b", "b"),
        ("c", "c"),
    ]
    assert paired_names == expected

def test_different_structure(self):
    """Check that mismatched child names raise only after the roots are yielded."""
    left: NamedNode = NamedNode(name="a", children={"b": NamedNode()})
    right: NamedNode = NamedNode(name="a", children={"c": NamedNode()})
    pairs = zip_subtrees(left, right)

    # The first step succeeds: the roots are handed back as a pair.
    root_left, root_right = next(pairs)
    assert root_left.name == "a"
    assert root_right.name == "a"

    # The next step fails because the roots' child names differ.
    expected_message = re.escape(
        r"children at '/' do not match: ['b'] vs ['c']"
    )
    with pytest.raises(ValueError, match=expected_message):
        next(pairs)


class TestAncestry:

Expand All @@ -343,7 +373,6 @@ def test_ancestors(self):

def test_subtree(self):
root, _ = create_test_tree()
subtree = root.subtree
expected = [
"a",
"b",
Expand All @@ -355,8 +384,8 @@ def test_subtree(self):
"g",
"i",
]
for node, expected_name in zip(subtree, expected, strict=True):
assert node.name == expected_name
actual = [node.name for node in root.subtree]
assert expected == actual

def test_descendants(self):
root, _ = create_test_tree()
Expand Down

0 comments on commit bde843c

Please sign in to comment.