diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 78abf42601d..f3e7ce348b1 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -56,6 +56,9 @@ def check_isomorphic( Also optionally raised if their structure is isomorphic, but the names of any two respective nodes are not equal. """ + # TODO: remove require_names_equal and check_from_root. Instead, check that + # all child nodes match, in any order, which will suffice once + # map_over_datasets switches to use zip_subtrees. if not isinstance(a, TreeNode): raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}") @@ -68,7 +71,7 @@ def check_isomorphic( diff = diff_treestructure(a, b, require_names_equal=require_names_equal) - if diff: + if diff is not None: raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 38dccfa2038..9b0ab8a8ae6 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -21,7 +21,6 @@ from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_equiv, astype from xarray.core.indexing import MemoryCachedArray -from xarray.core.iterators import LevelOrderIter from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.utils import is_duck_array from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy @@ -981,16 +980,28 @@ def diff_array_repr(a, b, compat): return "\n".join(summary) -def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: +def diff_treestructure( + a: DataTree, b: DataTree, require_names_equal: bool +) -> str | None: """ Return a summary of why two trees are not isomorphic. - If they are isomorphic return an empty string. + If they are isomorphic return None. 
""" + # .subtree walks nodes in breadth-first order, in order to produce as + # shallow a diff as possible + + # TODO: switch zip(a.subtree, b.subtree) to zip_subtrees(a, b), and only + # check that child node names match, e.g., + # for node_a, node_b in zip_subtrees(a, b): + # if node_a.children.keys() != node_b.children.keys(): + # diff = dedent( + # f"""\ + # Node {node_a.path!r} in the left object has children {list(node_a.children.keys())} + # Node {node_b.path!r} in the right object has children {list(node_b.children.keys())}""" + # ) + # return diff - # Walking nodes in "level-order" fashion means walking down from the root breadth-first. - # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree - # (which it is so long as children are stored in a tuple or list rather than in a set). - for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b), strict=True): + for node_a, node_b in zip(a.subtree, b.subtree, strict=True): path_a, path_b = node_a.path, node_b.path if require_names_equal and node_a.name != node_b.name: @@ -1009,7 +1020,7 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s ) return diff - return "" + return None def diff_dataset_repr(a, b, compat): @@ -1063,7 +1074,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): # If the trees structures are different there is no point comparing each node # TODO we could show any differences in nodes up to the first place that structure differs? 
- if treestructure_diff or compat == "isomorphic": + if treestructure_diff is not None or compat == "isomorphic": summary.append("\n" + treestructure_diff) else: nodewise_diff = diff_nodewise_summary(a, b, compat) diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py deleted file mode 100644 index eeaeb35aa9c..00000000000 --- a/xarray/core/iterators.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Iterator - -from xarray.core.treenode import Tree - -"""These iterators are copied from anytree.iterators, with minor modifications.""" - - -class LevelOrderIter(Iterator): - """Iterate over tree applying level-order strategy starting at `node`. - This is the iterator used by `DataTree` to traverse nodes. - - Parameters - ---------- - node : Tree - Node in a tree to begin iteration at. - filter_ : Callable, optional - Function called with every `node` as argument, `node` is returned if `True`. - Default is to iterate through all ``node`` objects in the tree. - stop : Callable, optional - Function that will cause iteration to stop if ``stop`` returns ``True`` - for ``node``. - maxlevel : int, optional - Maximum level to descend in the node hierarchy. - - Examples - -------- - >>> from xarray.core.datatree import DataTree - >>> from xarray.core.iterators import LevelOrderIter - >>> f = DataTree.from_dict( - ... {"/b/a": None, "/b/d/c": None, "/b/d/e": None, "/g/h/i": None}, name="f" - ... ) - >>> print(f) - - Group: / - ├── Group: /b - │ ├── Group: /b/a - │ └── Group: /b/d - │ ├── Group: /b/d/c - │ └── Group: /b/d/e - └── Group: /g - └── Group: /g/h - └── Group: /g/h/i - >>> [node.name for node in LevelOrderIter(f)] - ['f', 'b', 'g', 'a', 'd', 'h', 'c', 'e', 'i'] - >>> [node.name for node in LevelOrderIter(f, maxlevel=3)] - ['f', 'b', 'g', 'a', 'd', 'h'] - >>> [ - ... node.name - ... for node in LevelOrderIter(f, filter_=lambda n: n.name not in ("e", "g")) - ... 
] - ['f', 'b', 'a', 'd', 'h', 'c', 'i'] - >>> [node.name for node in LevelOrderIter(f, stop=lambda n: n.name == "d")] - ['f', 'b', 'g', 'a', 'h', 'i'] - """ - - def __init__( - self, - node: Tree, - filter_: Callable | None = None, - stop: Callable | None = None, - maxlevel: int | None = None, - ): - self.node = node - self.filter_ = filter_ - self.stop = stop - self.maxlevel = maxlevel - self.__iter = None - - def __init(self): - node = self.node - maxlevel = self.maxlevel - filter_ = self.filter_ or LevelOrderIter.__default_filter - stop = self.stop or LevelOrderIter.__default_stop - children = ( - [] - if LevelOrderIter._abort_at_level(1, maxlevel) - else LevelOrderIter._get_children([node], stop) - ) - return self._iter(children, filter_, stop, maxlevel) - - @staticmethod - def __default_filter(node: Tree) -> bool: - return True - - @staticmethod - def __default_stop(node: Tree) -> bool: - return False - - def __iter__(self) -> Iterator[Tree]: - return self - - def __next__(self) -> Iterator[Tree]: - if self.__iter is None: - self.__iter = self.__init() - item = next(self.__iter) # type: ignore[call-overload] - return item - - @staticmethod - def _abort_at_level(level: int, maxlevel: int | None) -> bool: - return maxlevel is not None and level > maxlevel - - @staticmethod - def _get_children(children: list[Tree], stop: Callable) -> list[Tree]: - return [child for child in children if not stop(child)] - - @staticmethod - def _iter( - children: list[Tree], filter_: Callable, stop: Callable, maxlevel: int | None - ) -> Iterator[Tree]: - level = 1 - while children: - next_children = [] - for child in children: - if filter_(child): - yield child - next_children += LevelOrderIter._get_children( - list(child.children.values()), stop - ) - children = next_children - level += 1 - if LevelOrderIter._abort_at_level(level, maxlevel): - break diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 604eb274aa9..a06e40b3688 100644 --- a/xarray/core/treenode.py 
+++ b/xarray/core/treenode.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections import sys from collections.abc import Iterator, Mapping from pathlib import PurePosixPath @@ -400,15 +401,18 @@ def subtree(self: Tree) -> Iterator[Tree]: """ An iterator over all nodes in this tree, including both self and all descendants. - Iterates depth-first. + Iterates breadth-first. See Also -------- DataTree.descendants """ - from xarray.core.iterators import LevelOrderIter - - return LevelOrderIter(self) + # https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode + queue = collections.deque([self]) + while queue: + node = queue.popleft() + yield node + queue.extend(node.children.values()) @property def descendants(self: Tree) -> tuple[Tree, ...]: @@ -773,3 +777,41 @@ def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath: generation_gap = list(parents_paths).index(ancestor.path) path_upwards = "../" * generation_gap if generation_gap > 0 else "." return NodePath(path_upwards) + + +def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]: + """Iterate over aligned subtrees in breadth-first order. + + Parameters + ----------
 + *trees : Tree + Trees to iterate over. + + Yields + ------ + Tuples of matching subtrees. 
+ """ + # https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode + queue = collections.deque([trees]) + + while queue: + active_nodes = queue.popleft() + + # yield before raising an error, in case the caller chooses to exit + # iteration early + yield active_nodes + + first_node = active_nodes[0] + if any( + sibling.children.keys() != first_node.children.keys() + for sibling in active_nodes[1:] + ): + child_summary = " vs ".join( + str(list(node.children)) for node in active_nodes + ) + raise ValueError( + f"children at {first_node.path!r} do not match: {child_summary}" + ) + + for name in first_node.children: + queue.append(tuple(node.children[name] for node in active_nodes)) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 1db9c594247..4d7d594ac32 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -1,12 +1,16 @@ from __future__ import annotations -from collections.abc import Iterator -from typing import cast +import re import pytest -from xarray.core.iterators import LevelOrderIter -from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode +from xarray.core.treenode import ( + InvalidTreeError, + NamedNode, + NodePath, + TreeNode, + zip_subtrees, +) class TestFamilyTree: @@ -299,15 +303,12 @@ def create_test_tree() -> tuple[NamedNode, NamedNode]: return a, f -class TestIterators: +class TestZipSubtrees: - def test_levelorderiter(self): + def test_one_tree(self): root, _ = create_test_tree() - result: list[str | None] = [ - node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root)) - ] expected = [ - "a", # root Node is unnamed + "a", "b", "c", "d", @@ -317,8 +318,37 @@ def test_levelorderiter(self): "g", "i", ] + result = [node[0].name for node in zip_subtrees(root)] assert result == expected + def test_different_order(self): + first: NamedNode = NamedNode( + name="a", children={"b": NamedNode(), "c": NamedNode()} + ) + second: NamedNode = NamedNode( + 
name="a", children={"c": NamedNode(), "b": NamedNode()} + ) + assert [node.name for node in first.subtree] == ["a", "b", "c"] + assert [node.name for node in second.subtree] == ["a", "c", "b"] + assert [(x.name, y.name) for x, y in zip_subtrees(first, second)] == [ + ("a", "a"), + ("b", "b"), + ("c", "c"), + ] + + def test_different_structure(self): + first: NamedNode = NamedNode(name="a", children={"b": NamedNode()}) + second: NamedNode = NamedNode(name="a", children={"c": NamedNode()}) + it = zip_subtrees(first, second) + + x, y = next(it) + assert x.name == y.name == "a" + + with pytest.raises( + ValueError, match=re.escape(r"children at '/' do not match: ['b'] vs ['c']") + ): + next(it) + class TestAncestry: @@ -343,7 +373,6 @@ def test_ancestors(self): def test_subtree(self): root, _ = create_test_tree() - subtree = root.subtree expected = [ "a", "b", @@ -355,8 +384,8 @@ def test_subtree(self): "g", "i", ] - for node, expected_name in zip(subtree, expected, strict=True): - assert node.name == expected_name + actual = [node.name for node in root.subtree] + assert expected == actual def test_descendants(self): root, _ = create_test_tree()