From 6d62719061f7b9cbed3ea01c4adf2ce89c4f015f Mon Sep 17 00:00:00 2001 From: owenlittlejohns Date: Fri, 26 Apr 2024 11:29:18 -0600 Subject: [PATCH] Migrate datatreee assertions/extensions/formatting (#8967) * DAS-2067 - Migrate formatting.py. * DAS-2067 - Migrate datatree/extensions.py. * DAS-2067 - Migrate datatree/tests/test_dataset_api.py. * DAS-2067 - Migrate datatree_render.py. * DAS-2067 - Migrate DataTree assertions into xarray/testing/assertions.py. * DAS-2067 - Update doc/whats-new.rst. * DAS-2067 - Fix doctests for DataTreeRender.by_attr. * DAS-2067 - Fix comments in doctests examples for datatree_render. * DAS-2067 - Implement PR feedback, fix RenderDataTree.__str__. * DAS-2067 - Add overload for xarray.testing.assert_equal and xarray.testing.assert_identical. * DAS-2067 - Remove out-of-date comments. * Remove test of printing datatree --------- Co-authored-by: Tom Nicholas --- doc/whats-new.rst | 9 +- xarray/core/datatree.py | 6 +- xarray/core/datatree_mapping.py | 37 +-- xarray/core/datatree_render.py | 266 +++++++++++++++++ xarray/core/extensions.py | 18 ++ xarray/core/formatting.py | 115 ++++++++ xarray/datatree_/datatree/extensions.py | 20 -- xarray/datatree_/datatree/formatting.py | 91 ------ xarray/datatree_/datatree/ops.py | 2 +- xarray/datatree_/datatree/render.py | 271 ------------------ xarray/datatree_/datatree/testing.py | 120 -------- xarray/datatree_/datatree/tests/__init__.py | 29 -- xarray/datatree_/datatree/tests/conftest.py | 65 ----- .../datatree/tests/test_dataset_api.py | 98 ------- .../datatree/tests/test_extensions.py | 40 --- .../datatree/tests/test_formatting.py | 120 -------- xarray/testing/__init__.py | 1 + xarray/testing/assertions.py | 106 ++++++- xarray/tests/test_backends_datatree.py | 2 +- xarray/tests/test_datatree.py | 178 +++++++++--- xarray/tests/test_datatree_mapping.py | 2 +- xarray/tests/test_extensions.py | 9 + xarray/tests/test_formatting.py | 103 +++++++ 23 files changed, 763 insertions(+), 945 deletions(-) create mode 100644 xarray/core/datatree_render.py delete mode 100644 xarray/datatree_/datatree/extensions.py delete mode 100644 xarray/datatree_/datatree/formatting.py delete mode 100644 xarray/datatree_/datatree/render.py delete mode 100644 xarray/datatree_/datatree/testing.py delete mode 100644 xarray/datatree_/datatree/tests/__init__.py delete mode 100644 xarray/datatree_/datatree/tests/conftest.py delete mode 100644 xarray/datatree_/datatree/tests/test_dataset_api.py delete mode 100644 xarray/datatree_/datatree/tests/test_extensions.py delete mode 100644 xarray/datatree_/datatree/tests/test_formatting.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 413ff091d01..8b05b729a31 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,12 +42,17 @@ Bug fixes Internal Changes ~~~~~~~~~~~~~~~~ -- Migrates ``formatting_html`` functionality for `DataTree` into ``xarray/core`` (:pull: `8930`) +- Migrates ``formatting_html`` functionality for ``DataTree`` into ``xarray/core`` (:pull: `8930`) By `Eni Awowale `_, `Julia Signell `_ and `Tom Nicholas `_. - Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`) By `Matt Savoie `_ `Owen Littlejohns - ` and `Tom Nicholas `_. + `_ and `Tom Nicholas `_. +- Migrates ``extensions``, ``formatting`` and ``datatree_render`` functionality for + ``DataTree`` into ``xarray/core``. Also migrates ``testing`` functionality into + ``xarray/testing/assertions`` for ``DataTree``. (:pull:`8967`) + By `Owen Littlejohns `_ and + `Tom Nicholas `_. .. _whats-new.2024.03.0: diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 57fd7222898..48c714b697c 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -23,6 +23,8 @@ check_isomorphic, map_over_subtree, ) +from xarray.core.datatree_render import RenderDataTree +from xarray.core.formatting import datatree_repr from xarray.core.formatting_html import ( datatree_repr as datatree_repr_html, ) @@ -40,13 +42,11 @@ ) from xarray.core.variable import Variable from xarray.datatree_.datatree.common import TreeAttrAccessMixin -from xarray.datatree_.datatree.formatting import datatree_repr from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, ) -from xarray.datatree_.datatree.render import RenderTree try: from xarray.core.variable import calculate_dimensions @@ -1451,7 +1451,7 @@ def pipe( def render(self): """Print tree structure, including any data stored at each node.""" - for pre, fill, node in RenderTree(self): + for pre, fill, node in RenderDataTree(self): print(f"{pre}DataTree('{self.name}')") for ds_line in repr(node.ds)[1:]: print(f"{fill}{ds_line}") diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 714921d2a90..4da934f2085 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -3,11 +3,11 @@ import functools import sys from itertools import repeat -from textwrap import dedent from typing import TYPE_CHECKING, Callable -from xarray import DataArray, Dataset -from xarray.core.iterators import LevelOrderIter +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.formatting import diff_treestructure from xarray.core.treenode import NodePath, TreeNode if TYPE_CHECKING: @@ -71,37 +71,6 @@ def check_isomorphic( raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) -def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: - """ - Return a summary of why two trees are not isomorphic. - If they are isomorphic return an empty string. - """ - - # Walking nodes in "level-order" fashion means walking down from the root breadth-first. - # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree - # (which it is so long as children are stored in a tuple or list rather than in a set). - for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): - path_a, path_b = node_a.path, node_b.path - - if require_names_equal and node_a.name != node_b.name: - diff = dedent( - f"""\ - Node '{path_a}' in the left object has name '{node_a.name}' - Node '{path_b}' in the right object has name '{node_b.name}'""" - ) - return diff - - if len(node_a.children) != len(node_b.children): - diff = dedent( - f"""\ - Number of children on node '{path_a}' of the left object: {len(node_a.children)} - Number of children on node '{path_b}' of the right object: {len(node_b.children)}""" - ) - return diff - - return "" - - def map_over_subtree(func: Callable) -> Callable: """ Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. diff --git a/xarray/core/datatree_render.py b/xarray/core/datatree_render.py new file mode 100644 index 00000000000..d069071495e --- /dev/null +++ b/xarray/core/datatree_render.py @@ -0,0 +1,266 @@ +""" +String Tree Rendering. Copied from anytree. + +Minor changes to `RenderDataTree` include accessing `children.values()`, and +type hints. + +""" + +from __future__ import annotations + +from collections import namedtuple +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from xarray.core.datatree import DataTree + +Row = namedtuple("Row", ("pre", "fill", "node")) + + +class AbstractStyle: + def __init__(self, vertical: str, cont: str, end: str): + """ + Tree Render Style. + Args: + vertical: Sign for vertical line. + cont: Chars for a continued branch. + end: Chars for the last branch. + """ + super().__init__() + self.vertical = vertical + self.cont = cont + self.end = end + assert ( + len(cont) == len(vertical) == len(end) + ), f"'{vertical}', '{cont}' and '{end}' need to have equal length" + + @property + def empty(self) -> str: + """Empty string as placeholder.""" + return " " * len(self.end) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class ContStyle(AbstractStyle): + def __init__(self): + """ + Continued style, without gaps. + + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root") + >>> s0 = DataTree(name="sub0", parent=root) + >>> s0b = DataTree(name="sub0B", parent=s0) + >>> s0a = DataTree(name="sub0A", parent=s0) + >>> s1 = DataTree(name="sub1", parent=root) + >>> print(RenderDataTree(root)) + DataTree('root', parent=None) + ├── DataTree('sub0') + │ ├── DataTree('sub0B') + │ └── DataTree('sub0A') + └── DataTree('sub1') + """ + super().__init__("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ") + + +class RenderDataTree: + def __init__( + self, + node: DataTree, + style=ContStyle(), + childiter: type = list, + maxlevel: int | None = None, + ): + """ + Render tree starting at `node`. + Keyword Args: + style (AbstractStyle): Render Style. + childiter: Child iterator. Note, due to the use of node.children.values(), + Iterables that change the order of children cannot be used + (e.g., `reversed`). + maxlevel: Limit rendering to this depth. + :any:`RenderDataTree` is an iterator, returning a tuple with 3 items: + `pre` + tree prefix. + `fill` + filling for multiline entries. + `node` + :any:`NodeMixin` object. + It is up to the user to assemble these parts to a whole. + + Examples + -------- + + >>> from xarray import Dataset + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1})) + >>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3})) + >>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4})) + >>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6})) + >>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7})) + + # Simple one line: + + >>> for pre, _, node in RenderDataTree(root): + ... print(f"{pre}{node.name}") + ... + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + + # Multiline: + + >>> for pre, fill, node in RenderDataTree(root): + ... print(f"{pre}{node.name}") + ... for variable in node.variables: + ... print(f"{fill}{variable}") + ... + root + a + b + ├── sub0 + │ c + │ d + │ ├── sub0B + │ │ e + │ └── sub0A + │ f + │ g + └── sub1 + h + + :any:`by_attr` simplifies attribute rendering and supports multiline: + >>> print(RenderDataTree(root).by_attr()) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + + # `maxlevel` limits the depth of the tree: + + >>> print(RenderDataTree(root, maxlevel=2).by_attr("name")) + root + ├── sub0 + └── sub1 + """ + if not isinstance(style, AbstractStyle): + style = style() + self.node = node + self.style = style + self.childiter = childiter + self.maxlevel = maxlevel + + def __iter__(self) -> Iterator[Row]: + return self.__next(self.node, tuple()) + + def __next( + self, node: DataTree, continues: tuple[bool, ...], level: int = 0 + ) -> Iterator[Row]: + yield RenderDataTree.__item(node, continues, self.style) + children = node.children.values() + level += 1 + if children and (self.maxlevel is None or level < self.maxlevel): + children = self.childiter(children) + for child, is_last in _is_last(children): + yield from self.__next(child, continues + (not is_last,), level=level) + + @staticmethod + def __item( + node: DataTree, continues: tuple[bool, ...], style: AbstractStyle + ) -> Row: + if not continues: + return Row("", "", node) + else: + items = [style.vertical if cont else style.empty for cont in continues] + indent = "".join(items[:-1]) + branch = style.cont if continues[-1] else style.end + pre = indent + branch + fill = "".join(items) + return Row(pre, fill, node) + + def __str__(self) -> str: + return str(self.node) + + def __repr__(self) -> str: + classname = self.__class__.__name__ + args = [ + repr(self.node), + f"style={repr(self.style)}", + f"childiter={repr(self.childiter)}", + ] + return f"{classname}({', '.join(args)})" + + def by_attr(self, attrname: str = "name") -> str: + """ + Return rendered tree with node attribute `attrname`. + + Examples + -------- + + >>> from xarray import Dataset + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root") + >>> s0 = DataTree(name="sub0", parent=root) + >>> s0b = DataTree( + ... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109}) + ... ) + >>> s0a = DataTree(name="sub0A", parent=s0) + >>> s1 = DataTree(name="sub1", parent=root) + >>> s1a = DataTree(name="sub1A", parent=s1) + >>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8})) + >>> s1c = DataTree(name="sub1C", parent=s1) + >>> s1ca = DataTree(name="sub1Ca", parent=s1c) + >>> print(RenderDataTree(root).by_attr("name")) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + ├── sub1A + ├── sub1B + └── sub1C + └── sub1Ca + """ + + def get() -> Iterator[str]: + for pre, fill, node in self: + attr = ( + attrname(node) + if callable(attrname) + else getattr(node, attrname, "") + ) + if isinstance(attr, (list, tuple)): + lines = attr + else: + lines = str(attr).split("\n") + yield f"{pre}{lines[0]}" + for line in lines[1:]: + yield f"{fill}{line}" + + return "\n".join(get()) + + +def _is_last(iterable: Iterable) -> Iterator[tuple[DataTree, bool]]: + iter_ = iter(iterable) + try: + nextitem = next(iter_) + except StopIteration: + pass + else: + item = nextitem + while True: + try: + nextitem = next(iter_) + yield item, False + except StopIteration: + yield nextitem, True + break + item = nextitem diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py index efe00718a79..9ebbd564f4f 100644 --- a/xarray/core/extensions.py +++ b/xarray/core/extensions.py @@ -4,6 +4,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree class AccessorRegistrationWarning(Warning): @@ -121,3 +122,20 @@ def register_dataset_accessor(name): register_dataarray_accessor """ return _register_accessor(name, Dataset) + + +def register_datatree_accessor(name): + """Register a custom accessor on DataTree objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + See Also + -------- + xarray.register_dataarray_accessor + xarray.register_dataset_accessor + """ + return _register_accessor(name, DataTree) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 3eed7d02a2e..ad65a44d7d5 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -11,20 +11,24 @@ from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr +from textwrap import dedent from typing import TYPE_CHECKING import numpy as np import pandas as pd from pandas.errors import OutOfBoundsDatetime +from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_equiv, astype from xarray.core.indexing import MemoryCachedArray +from xarray.core.iterators import LevelOrderIter from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.utils import is_duck_array from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy if TYPE_CHECKING: from xarray.core.coordinates import AbstractCoordinates + from xarray.core.datatree import DataTree UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") @@ -926,6 +930,37 @@ def diff_array_repr(a, b, compat): return "\n".join(summary) +def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: + """ + Return a summary of why two trees are not isomorphic. + If they are isomorphic return an empty string. + """ + + # Walking nodes in "level-order" fashion means walking down from the root breadth-first. + # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree + # (which it is so long as children are stored in a tuple or list rather than in a set). + for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): + path_a, path_b = node_a.path, node_b.path + + if require_names_equal and node_a.name != node_b.name: + diff = dedent( + f"""\ + Node '{path_a}' in the left object has name '{node_a.name}' + Node '{path_b}' in the right object has name '{node_b.name}'""" + ) + return diff + + if len(node_a.children) != len(node_b.children): + diff = dedent( + f"""\ + Number of children on node '{path_a}' of the left object: {len(node_a.children)} + Number of children on node '{path_b}' of the right object: {len(node_b.children)}""" + ) + return diff + + return "" + + def diff_dataset_repr(a, b, compat): summary = [ f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" @@ -945,6 +980,86 @@ def diff_dataset_repr(a, b, compat): return "\n".join(summary) +def diff_nodewise_summary(a: DataTree, b: DataTree, compat): + """Iterates over all corresponding nodes, recording differences between data at each location.""" + + compat_str = _compat_to_str(compat) + + summary = [] + for node_a, node_b in zip(a.subtree, b.subtree): + a_ds, b_ds = node_a.ds, node_b.ds + + if not a_ds._all_compat(b_ds, compat): + dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str) + data_diff = "\n".join(dataset_diff.split("\n", 1)[1:]) + + nodediff = ( + f"\nData in nodes at position '{node_a.path}' do not match:" + f"{data_diff}" + ) + summary.append(nodediff) + + return "\n".join(summary) + + +def diff_datatree_repr(a: DataTree, b: DataTree, compat): + summary = [ + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" + ] + + strict_names = True if compat in ["equals", "identical"] else False + treestructure_diff = diff_treestructure(a, b, strict_names) + + # If the trees structures are different there is no point comparing each node + # TODO we could show any differences in nodes up to the first place that structure differs? + if treestructure_diff or compat == "isomorphic": + summary.append("\n" + treestructure_diff) + else: + nodewise_diff = diff_nodewise_summary(a, b, compat) + summary.append("\n" + nodewise_diff) + + return "\n".join(summary) + + +def _single_node_repr(node: DataTree) -> str: + """Information about this node, not including its relationships to other nodes.""" + node_info = f"DataTree('{node.name}')" + + if node.has_data or node.has_attrs: + ds_info = "\n" + repr(node.ds) + else: + ds_info = "" + return node_info + ds_info + + +def datatree_repr(dt: DataTree): + """A printable representation of the structure of this entire tree.""" + renderer = RenderDataTree(dt) + + lines = [] + for pre, fill, node in renderer: + node_repr = _single_node_repr(node) + + node_line = f"{pre}{node_repr.splitlines()[0]}" + lines.append(node_line) + + if node.has_data or node.has_attrs: + ds_repr = node_repr.splitlines()[2:] + for line in ds_repr: + if len(node.children) > 0: + lines.append(f"{fill}{renderer.style.vertical}{line}") + else: + lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") + + # Tack on info about whether or not root node has a parent at the start + first_line = lines[0] + parent = f'"{dt.parent.name}"' if dt.parent is not None else "None" + first_line_with_parent = first_line[:-1] + f", parent={parent})" + lines[0] = first_line_with_parent + + return "\n".join(lines) + + def shorten_list_repr(items: Sequence, max_items: int) -> str: if len(items) <= max_items: return repr(items) diff --git a/xarray/datatree_/datatree/extensions.py b/xarray/datatree_/datatree/extensions.py deleted file mode 100644 index bf888fc4484..00000000000 --- a/xarray/datatree_/datatree/extensions.py +++ /dev/null @@ -1,20 +0,0 @@ -from xarray.core.extensions import _register_accessor - -from xarray.core.datatree import DataTree - - -def register_datatree_accessor(name): - """Register a custom accessor on DataTree objects. - - Parameters - ---------- - name : str - Name under which the accessor should be registered. A warning is issued - if this name conflicts with a preexisting attribute. - - See Also - -------- - xarray.register_dataarray_accessor - xarray.register_dataset_accessor - """ - return _register_accessor(name, DataTree) diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py deleted file mode 100644 index fdd23933ae6..00000000000 --- a/xarray/datatree_/datatree/formatting.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import TYPE_CHECKING - -from xarray.core.formatting import _compat_to_str, diff_dataset_repr - -from xarray.core.datatree_mapping import diff_treestructure -from xarray.datatree_.datatree.render import RenderTree - -if TYPE_CHECKING: - from xarray.core.datatree import DataTree - - -def diff_nodewise_summary(a, b, compat): - """Iterates over all corresponding nodes, recording differences between data at each location.""" - - compat_str = _compat_to_str(compat) - - summary = [] - for node_a, node_b in zip(a.subtree, b.subtree): - a_ds, b_ds = node_a.ds, node_b.ds - - if not a_ds._all_compat(b_ds, compat): - dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str) - data_diff = "\n".join(dataset_diff.split("\n", 1)[1:]) - - nodediff = ( - f"\nData in nodes at position '{node_a.path}' do not match:" - f"{data_diff}" - ) - summary.append(nodediff) - - return "\n".join(summary) - - -def diff_tree_repr(a, b, compat): - summary = [ - f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" - ] - - # TODO check root parents? - - strict_names = True if compat in ["equals", "identical"] else False - treestructure_diff = diff_treestructure(a, b, strict_names) - - # If the trees structures are different there is no point comparing each node - # TODO we could show any differences in nodes up to the first place that structure differs? - if treestructure_diff or compat == "isomorphic": - summary.append("\n" + treestructure_diff) - else: - nodewise_diff = diff_nodewise_summary(a, b, compat) - summary.append("\n" + nodewise_diff) - - return "\n".join(summary) - - -def datatree_repr(dt): - """A printable representation of the structure of this entire tree.""" - renderer = RenderTree(dt) - - lines = [] - for pre, fill, node in renderer: - node_repr = _single_node_repr(node) - - node_line = f"{pre}{node_repr.splitlines()[0]}" - lines.append(node_line) - - if node.has_data or node.has_attrs: - ds_repr = node_repr.splitlines()[2:] - for line in ds_repr: - if len(node.children) > 0: - lines.append(f"{fill}{renderer.style.vertical}{line}") - else: - lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") - - # Tack on info about whether or not root node has a parent at the start - first_line = lines[0] - parent = f'"{dt.parent.name}"' if dt.parent is not None else "None" - first_line_with_parent = first_line[:-1] + f", parent={parent})" - lines[0] = first_line_with_parent - - return "\n".join(lines) - - -def _single_node_repr(node: "DataTree") -> str: - """Information about this node, not including its relationships to other nodes.""" - node_info = f"DataTree('{node.name}')" - - if node.has_data or node.has_attrs: - ds_info = "\n" + repr(node.ds) - else: - ds_info = "" - return node_info + ds_info diff --git a/xarray/datatree_/datatree/ops.py b/xarray/datatree_/datatree/ops.py index 83b9d1b275a..1ca8a7c1e01 100644 --- a/xarray/datatree_/datatree/ops.py +++ b/xarray/datatree_/datatree/ops.py @@ -1,6 +1,6 @@ import textwrap -from xarray import Dataset +from xarray.core.dataset import Dataset from xarray.core.datatree_mapping import map_over_subtree diff --git a/xarray/datatree_/datatree/render.py b/xarray/datatree_/datatree/render.py deleted file mode 100644 index e6af9c85ee8..00000000000 --- a/xarray/datatree_/datatree/render.py +++ /dev/null @@ -1,271 +0,0 @@ -""" -String Tree Rendering. Copied from anytree. -""" - -import collections -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from xarray.core.datatree import DataTree - -Row = collections.namedtuple("Row", ("pre", "fill", "node")) - - -class AbstractStyle(object): - def __init__(self, vertical, cont, end): - """ - Tree Render Style. - Args: - vertical: Sign for vertical line. - cont: Chars for a continued branch. - end: Chars for the last branch. - """ - super(AbstractStyle, self).__init__() - self.vertical = vertical - self.cont = cont - self.end = end - assert ( - len(cont) == len(vertical) == len(end) - ), f"'{vertical}', '{cont}' and '{end}' need to have equal length" - - @property - def empty(self): - """Empty string as placeholder.""" - return " " * len(self.end) - - def __repr__(self): - return f"{self.__class__.__name__}()" - - -class ContStyle(AbstractStyle): - def __init__(self): - """ - Continued style, without gaps. - - >>> from anytree import Node, RenderTree - >>> root = Node("root") - >>> s0 = Node("sub0", parent=root) - >>> s0b = Node("sub0B", parent=s0) - >>> s0a = Node("sub0A", parent=s0) - >>> s1 = Node("sub1", parent=root) - >>> print(RenderTree(root, style=ContStyle())) - - Node('/root') - ├── Node('/root/sub0') - │ ├── Node('/root/sub0/sub0B') - │ └── Node('/root/sub0/sub0A') - └── Node('/root/sub1') - """ - super(ContStyle, self).__init__( - "\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 " - ) - - -class RenderTree(object): - def __init__( - self, node: "DataTree", style=ContStyle(), childiter=list, maxlevel=None - ): - """ - Render tree starting at `node`. - Keyword Args: - style (AbstractStyle): Render Style. - childiter: Child iterator. - maxlevel: Limit rendering to this depth. - :any:`RenderTree` is an iterator, returning a tuple with 3 items: - `pre` - tree prefix. - `fill` - filling for multiline entries. - `node` - :any:`NodeMixin` object. - It is up to the user to assemble these parts to a whole. - >>> from anytree import Node, RenderTree - >>> root = Node("root", lines=["c0fe", "c0de"]) - >>> s0 = Node("sub0", parent=root, lines=["ha", "ba"]) - >>> s0b = Node("sub0B", parent=s0, lines=["1", "2", "3"]) - >>> s0a = Node("sub0A", parent=s0, lines=["a", "b"]) - >>> s1 = Node("sub1", parent=root, lines=["Z"]) - Simple one line: - >>> for pre, _, node in RenderTree(root): - ... print("%s%s" % (pre, node.name)) - ... - root - ├── sub0 - │ ├── sub0B - │ └── sub0A - └── sub1 - Multiline: - >>> for pre, fill, node in RenderTree(root): - ... print("%s%s" % (pre, node.lines[0])) - ... for line in node.lines[1:]: - ... print("%s%s" % (fill, line)) - ... - c0fe - c0de - ├── ha - │ ba - │ ├── 1 - │ │ 2 - │ │ 3 - │ └── a - │ b - └── Z - `maxlevel` limits the depth of the tree: - >>> print(RenderTree(root, maxlevel=2)) - Node('/root', lines=['c0fe', 'c0de']) - ├── Node('/root/sub0', lines=['ha', 'ba']) - └── Node('/root/sub1', lines=['Z']) - The `childiter` is responsible for iterating over child nodes at the - same level. An reversed order can be achived by using `reversed`. - >>> for row in RenderTree(root, childiter=reversed): - ... print("%s%s" % (row.pre, row.node.name)) - ... - root - ├── sub1 - └── sub0 - ├── sub0A - └── sub0B - Or writing your own sort function: - >>> def mysort(items): - ... return sorted(items, key=lambda item: item.name) - ... - >>> for row in RenderTree(root, childiter=mysort): - ... print("%s%s" % (row.pre, row.node.name)) - ... - root - ├── sub0 - │ ├── sub0A - │ └── sub0B - └── sub1 - :any:`by_attr` simplifies attribute rendering and supports multiline: - >>> print(RenderTree(root).by_attr()) - root - ├── sub0 - │ ├── sub0B - │ └── sub0A - └── sub1 - >>> print(RenderTree(root).by_attr("lines")) - c0fe - c0de - ├── ha - │ ba - │ ├── 1 - │ │ 2 - │ │ 3 - │ └── a - │ b - └── Z - And can be a function: - >>> print(RenderTree(root).by_attr(lambda n: " ".join(n.lines))) - c0fe c0de - ├── ha ba - │ ├── 1 2 3 - │ └── a b - └── Z - """ - if not isinstance(style, AbstractStyle): - style = style() - self.node = node - self.style = style - self.childiter = childiter - self.maxlevel = maxlevel - - def __iter__(self): - return self.__next(self.node, tuple()) - - def __next(self, node, continues, level=0): - yield RenderTree.__item(node, continues, self.style) - children = node.children.values() - level += 1 - if children and (self.maxlevel is None or level < self.maxlevel): - children = self.childiter(children) - for child, is_last in _is_last(children): - for grandchild in self.__next( - child, continues + (not is_last,), level=level - ): - yield grandchild - - @staticmethod - def __item(node, continues, style): - if not continues: - return Row("", "", node) - else: - items = [style.vertical if cont else style.empty for cont in continues] - indent = "".join(items[:-1]) - branch = style.cont if continues[-1] else style.end - pre = indent + branch - fill = "".join(items) - return Row(pre, fill, node) - - def __str__(self): - lines = ["%s%r" % (pre, node) for pre, _, node in self] - return "\n".join(lines) - - def __repr__(self): - classname = self.__class__.__name__ - args = [ - repr(self.node), - "style=%s" % repr(self.style), - "childiter=%s" % repr(self.childiter), - ] - return "%s(%s)" % (classname, ", ".join(args)) - - def by_attr(self, attrname="name"): - """ - Return rendered tree with node attribute `attrname`. - >>> from anytree import AnyNode, RenderTree - >>> root = AnyNode(id="root") - >>> s0 = AnyNode(id="sub0", parent=root) - >>> s0b = AnyNode(id="sub0B", parent=s0, foo=4, bar=109) - >>> s0a = AnyNode(id="sub0A", parent=s0) - >>> s1 = AnyNode(id="sub1", parent=root) - >>> s1a = AnyNode(id="sub1A", parent=s1) - >>> s1b = AnyNode(id="sub1B", parent=s1, bar=8) - >>> s1c = AnyNode(id="sub1C", parent=s1) - >>> s1ca = AnyNode(id="sub1Ca", parent=s1c) - >>> print(RenderTree(root).by_attr("id")) - root - ├── sub0 - │ ├── sub0B - │ └── sub0A - └── sub1 - ├── sub1A - ├── sub1B - └── sub1C - └── sub1Ca - """ - - def get(): - for pre, fill, node in self: - attr = ( - attrname(node) - if callable(attrname) - else getattr(node, attrname, "") - ) - if isinstance(attr, (list, tuple)): - lines = attr - else: - lines = str(attr).split("\n") - yield "%s%s" % (pre, lines[0]) - for line in lines[1:]: - yield "%s%s" % (fill, line) - - return "\n".join(get()) - - -def _is_last(iterable): - iter_ = iter(iterable) - try: - nextitem = next(iter_) - except StopIteration: - pass - else: - item = nextitem - while True: - try: - nextitem = next(iter_) - yield item, False - except StopIteration: - yield nextitem, True - break - item = nextitem diff --git a/xarray/datatree_/datatree/testing.py b/xarray/datatree_/datatree/testing.py deleted file mode 100644 index bf54116725a..00000000000 --- a/xarray/datatree_/datatree/testing.py +++ /dev/null @@ -1,120 +0,0 @@ -from xarray.testing.assertions import ensure_warnings - -from xarray.core.datatree import DataTree -from .formatting import diff_tree_repr - - -@ensure_warnings -def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): - """ - Two DataTrees are considered isomorphic if every node has the same number of children. - - Nothing about the data in each node is checked. - - Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, - such as tree1 + tree2. - - By default this function does not check any part of the tree above the given node. - Therefore this function can be used as default to check that two subtrees are isomorphic. - - Parameters - ---------- - a : DataTree - The first object to compare. - b : DataTree - The second object to compare. - from_root : bool, optional, default is False - Whether or not to first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. - - See Also - -------- - DataTree.isomorphic - assert_equals - assert_identical - """ - __tracebackhide__ = True - assert isinstance(a, type(b)) - - if isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.isomorphic(b, from_root=from_root), diff_tree_repr(a, b, "isomorphic") - else: - raise TypeError(f"{type(a)} not of type DataTree") - - -@ensure_warnings -def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): - """ - Two DataTrees are equal if they have isomorphic node structures, with matching node names, - and if they have matching variables and coordinates, all of which are equal. - - By default this method will check the whole tree above the given node. - - Parameters - ---------- - a : DataTree - The first object to compare. - b : DataTree - The second object to compare. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. - - See Also - -------- - DataTree.equals - assert_isomorphic - assert_identical - """ - __tracebackhide__ = True - assert isinstance(a, type(b)) - - if isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.equals(b, from_root=from_root), diff_tree_repr(a, b, "equals") - else: - raise TypeError(f"{type(a)} not of type DataTree") - - -@ensure_warnings -def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): - """ - Like assert_equals, but will also check all dataset attributes and the attributes on - all variables and coordinates. - - By default this method will check the whole tree above the given node. - - Parameters - ---------- - a : xarray.DataTree - The first object to compare. - b : xarray.DataTree - The second object to compare. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. - - See Also - -------- - DataTree.identical - assert_isomorphic - assert_equal - """ - - __tracebackhide__ = True - assert isinstance(a, type(b)) - if isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.identical(b, from_root=from_root), diff_tree_repr(a, b, "identical") - else: - raise TypeError(f"{type(a)} not of type DataTree") diff --git a/xarray/datatree_/datatree/tests/__init__.py b/xarray/datatree_/datatree/tests/__init__.py deleted file mode 100644 index 64961158b13..00000000000 --- a/xarray/datatree_/datatree/tests/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -import importlib - -import pytest -from packaging import version - - -def _importorskip(modname, minversion=None): - try: - mod = importlib.import_module(modname) - has = True - if minversion is not None: - if LooseVersion(mod.__version__) < LooseVersion(minversion): - raise ImportError("Minimum version not satisfied") - except ImportError: - has = False - func = pytest.mark.skipif(not has, reason=f"requires {modname}") - return has, func - - -def LooseVersion(vstring): - # Our development version is something like '0.10.9+aac7bfc' - # This function just ignores the git commit id. - vstring = vstring.split("+")[0] - return version.parse(vstring) - - -has_zarr, requires_zarr = _importorskip("zarr") -has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") -has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") diff --git a/xarray/datatree_/datatree/tests/conftest.py b/xarray/datatree_/datatree/tests/conftest.py deleted file mode 100644 index 53a9a72239d..00000000000 --- a/xarray/datatree_/datatree/tests/conftest.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -import xarray as xr - -from xarray.core.datatree import DataTree - - -@pytest.fixture(scope="module") -def create_test_datatree(): - """ - Create a test datatree with this structure: - - - |-- set1 - | |-- - | | Dimensions: () - | | Data variables: - | | a int64 0 - | | b int64 1 - | |-- set1 - | |-- set2 - |-- set2 - | |-- - | | Dimensions: (x: 2) - | | Data variables: - | | a (x) int64 2, 3 - | | b (x) int64 0.1, 0.2 - | |-- set1 - |-- set3 - |-- - | Dimensions: (x: 2, y: 3) - | Data variables: - | a (y) int64 6, 7, 8 - | set0 (x) int64 9, 10 - - The structure has deliberately repeated names of tags, variables, and - dimensions in order to better check for bugs caused by name conflicts. - """ - - def _create_test_datatree(modify=lambda ds: ds): - set1_data = modify(xr.Dataset({"a": 0, "b": 1})) - set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) - root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) - - # Avoid using __init__ so we can independently test it - root = DataTree(data=root_data) - set1 = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2 = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) - - return root - - return _create_test_datatree - - -@pytest.fixture(scope="module") -def simple_datatree(create_test_datatree): - """ - Invoke create_test_datatree fixture (callback). - - Returns a DataTree. - """ - return create_test_datatree() diff --git a/xarray/datatree_/datatree/tests/test_dataset_api.py b/xarray/datatree_/datatree/tests/test_dataset_api.py deleted file mode 100644 index 4ca532ebba4..00000000000 --- a/xarray/datatree_/datatree/tests/test_dataset_api.py +++ /dev/null @@ -1,98 +0,0 @@ -import numpy as np -import xarray as xr - -from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.testing import assert_equal - - -class TestDSMethodInheritance: - def test_dataset_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.isel(x=1)) - DataTree(name="results", parent=expected, data=ds.isel(x=1)) - - result = dt.isel(x=1) - assert_equal(result, expected) - - def test_reduce_method(self): - ds = xr.Dataset({"a": ("x", [False, True, False])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.any()) - DataTree(name="results", parent=expected, data=ds.any()) - - result = dt.any() - assert_equal(result, expected) - - def test_nan_reduce_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.mean()) - DataTree(name="results", parent=expected, data=ds.mean()) - - result = dt.mean() - assert_equal(result, expected) - - def test_cum_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.cumsum()) - DataTree(name="results", parent=expected, data=ds.cumsum()) - - result = dt.cumsum() - assert_equal(result, expected) - - -class TestOps: - def test_binary_op_on_int(self): - ds1 = xr.Dataset({"a": [5], "b": [3]}) - ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - - expected = DataTree(data=ds1 * 5) - DataTree(name="subnode", data=ds2 * 5, parent=expected) - - result = dt * 5 - assert_equal(result, expected) - - def test_binary_op_on_dataset(self): - ds1 = xr.Dataset({"a": [5], "b": [3]}) - ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) - - expected = DataTree(data=ds1 * other_ds) - DataTree(name="subnode", data=ds2 * other_ds, parent=expected) - - result = dt * other_ds - assert_equal(result, expected) - - def test_binary_op_on_datatree(self): - ds1 = xr.Dataset({"a": [5], "b": [3]}) - ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - - expected = DataTree(data=ds1 * ds1) - DataTree(name="subnode", data=ds2 * ds2, parent=expected) - - result = dt * dt - assert_equal(result, expected) - - -class TestUFuncs: - def test_tree(self, create_test_datatree): - dt = create_test_datatree() - expected = create_test_datatree(modify=lambda ds: np.sin(ds)) - result_tree = np.sin(dt) - assert_equal(result_tree, expected) diff --git a/xarray/datatree_/datatree/tests/test_extensions.py b/xarray/datatree_/datatree/tests/test_extensions.py deleted file mode 100644 index 329696c7a9d..00000000000 --- a/xarray/datatree_/datatree/tests/test_extensions.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest - -from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.extensions import register_datatree_accessor - - -class TestAccessor: - def test_register(self) -> None: - @register_datatree_accessor("demo") - class DemoAccessor: - """Demo accessor.""" - - def __init__(self, xarray_obj): - self._obj = xarray_obj - - @property - def foo(self): - return "bar" - - dt: DataTree = DataTree() - assert dt.demo.foo == "bar" - - # accessor is cached - assert dt.demo is dt.demo - - # check descriptor - assert dt.demo.__doc__ == "Demo accessor." - assert DataTree.demo.__doc__ == "Demo accessor." # type: ignore - assert isinstance(dt.demo, DemoAccessor) - assert DataTree.demo is DemoAccessor # type: ignore - - with pytest.warns(Warning, match="overriding a preexisting attribute"): - - @register_datatree_accessor("demo") - class Foo: - pass - - # ensure we can remove it - del DataTree.demo # type: ignore - assert not hasattr(DataTree, "demo") diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py deleted file mode 100644 index 77f8346ae72..00000000000 --- a/xarray/datatree_/datatree/tests/test_formatting.py +++ /dev/null @@ -1,120 +0,0 @@ -from textwrap import dedent - -from xarray import Dataset - -from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.formatting import diff_tree_repr - - -class TestRepr: - def test_print_empty_node(self): - dt = DataTree(name="root") - printout = dt.__str__() - assert printout == "DataTree('root', parent=None)" - - def test_print_empty_node_with_attrs(self): - dat = Dataset(attrs={"note": "has attrs"}) - dt = DataTree(name="root", data=dat) - printout = dt.__str__() - assert printout == dedent( - """\ - DataTree('root', parent=None) - Dimensions: () - Data variables: - *empty* - Attributes: - note: has attrs""" - ) - - def test_print_node_with_data(self): - dat = Dataset({"a": [0, 2]}) - dt = DataTree(name="root", data=dat) - printout = dt.__str__() - expected = [ - "DataTree('root', parent=None)", - "Dimensions", - "Coordinates", - "a", - "Data variables", - "*empty*", - ] - for expected_line, printed_line in zip(expected, printout.splitlines()): - assert expected_line in printed_line - - def test_nested_node(self): - dat = Dataset({"a": [0, 2]}) - root = DataTree(name="root") - DataTree(name="results", data=dat, parent=root) - printout = root.__str__() - assert printout.splitlines()[2].startswith(" ") - - def test_print_datatree(self, simple_datatree): - dt = simple_datatree - print(dt) - - # TODO work out how to test something complex like this - - def test_repr_of_node_with_data(self): - dat = Dataset({"a": [0, 2]}) - dt = DataTree(name="root", data=dat) - assert "Coordinates" in repr(dt) - - -class TestDiffFormatting: - def test_diff_structure(self): - dt_1 = DataTree.from_dict({"a": None, "a/b": None, "a/c": None}) - dt_2 = DataTree.from_dict({"d": None, "d/e": None}) - - expected = dedent( - """\ - Left and right DataTree objects are not isomorphic - - Number of children on node '/a' of the left object: 2 - Number of children on node '/d' of the right object: 1""" - ) - actual = diff_tree_repr(dt_1, dt_2, "isomorphic") - assert actual == expected - - def test_diff_node_names(self): - dt_1 = DataTree.from_dict({"a": None}) - dt_2 = DataTree.from_dict({"b": None}) - - expected = dedent( - """\ - Left and right DataTree objects are not identical - - Node '/a' in the left object has name 'a' - Node '/b' in the right object has name 'b'""" - ) - actual = diff_tree_repr(dt_1, dt_2, "identical") - assert actual == expected - - def test_diff_node_data(self): - import numpy as np - - # casting to int64 explicitly ensures that int64s are created on all architectures - ds1 = Dataset({"u": np.int64(0), "v": np.int64(1)}) - ds3 = Dataset({"w": np.int64(5)}) - dt_1 = DataTree.from_dict({"a": ds1, "a/b": ds3}) - ds2 = Dataset({"u": np.int64(0)}) - ds4 = Dataset({"w": np.int64(6)}) - dt_2 = DataTree.from_dict({"a": ds2, "a/b": ds4}) - - expected = dedent( - """\ - Left and right DataTree objects are not equal - - - Data in nodes at position '/a' do not match: - - Data variables only on the left object: - v int64 8B 1 - - Data in nodes at position '/a/b' do not match: - - Differing data variables: - L w int64 8B 5 - R w int64 8B 6""" - ) - actual = diff_tree_repr(dt_1, dt_2, "equals") - assert actual == expected diff --git a/xarray/testing/__init__.py b/xarray/testing/__init__.py index ab2f8ba4357..316b0ea5252 100644 --- a/xarray/testing/__init__.py +++ b/xarray/testing/__init__.py @@ -1,3 +1,4 @@ +# TODO: Add assert_isomorphic when making DataTree API public from xarray.testing.assertions import ( # noqa: F401 _assert_dataarray_invariants, _assert_dataset_invariants, diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 6418eb79b8b..018874c169e 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -3,7 +3,7 @@ import functools import warnings from collections.abc import Hashable -from typing import Union +from typing import Union, overload import numpy as np import pandas as pd @@ -12,6 +12,8 @@ from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree +from xarray.core.formatting import diff_datatree_repr from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes from xarray.core.variable import IndexVariable, Variable @@ -50,7 +52,59 @@ def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=Tru @ensure_warnings -def assert_equal(a, b): +def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): + """ + Two DataTrees are considered isomorphic if every node has the same number of children. + + Nothing about the data or attrs in each node is checked. + + Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, + such as tree1 + tree2. + + By default this function does not check any part of the tree above the given node. + Therefore this function can be used as default to check that two subtrees are isomorphic. + + Parameters + ---------- + a : DataTree + The first object to compare. + b : DataTree + The second object to compare. + from_root : bool, optional, default is False + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + See Also + -------- + DataTree.isomorphic + assert_equal + assert_identical + """ + __tracebackhide__ = True + assert isinstance(a, type(b)) + + if isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.isomorphic(b, from_root=from_root), diff_datatree_repr( + a, b, "isomorphic" + ) + else: + raise TypeError(f"{type(a)} not of type DataTree") + + +@overload +def assert_equal(a, b): ... + + +@overload +def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... + + +@ensure_warnings +def assert_equal(a, b, from_root=True): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -59,12 +113,20 @@ def assert_equal(a, b): (except for Dataset objects for which the variable names must match). Arrays with NaN in the same location are considered equal. + For DataTree objects, assert_equal is mapped over all Datasets on each node, + with the DataTrees being equal if both are isomorphic and the corresponding + Datasets at each node are themselves equal. + Parameters ---------- - a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates - The first object to compare. - b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates - The second object to compare. + a : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates + or xarray.core.datatree.DataTree. The first object to compare. + b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates + or xarray.core.datatree.DataTree. The second object to compare. + from_root : bool, optional, default is True + Only used when comparing DataTree objects. Indicates whether or not to + first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. See Also -------- @@ -81,23 +143,45 @@ def assert_equal(a, b): assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals") elif isinstance(a, Coordinates): assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") + elif isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals") else: raise TypeError(f"{type(a)} not supported by assertion comparison") +@overload +def assert_identical(a, b): ... + + +@overload +def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ... + + @ensure_warnings -def assert_identical(a, b): +def assert_identical(a, b, from_root=True): """Like :py:func:`xarray.testing.assert_equal`, but also matches the objects' names and attributes. Raises an AssertionError if two objects are not identical. + For DataTree objects, assert_identical is mapped over all Datasets on each + node, with the DataTrees being identical if both are isomorphic and the + corresponding Datasets at each node are themselves identical. + Parameters ---------- a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. + from_root : bool, optional, default is True + Only used when comparing DataTree objects. Indicates whether or not to + first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. See Also -------- @@ -116,6 +200,14 @@ def assert_identical(a, b): assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical") elif isinstance(a, Coordinates): assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") + elif isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.identical(b, from_root=from_root), diff_datatree_repr( + a, b, "identical" + ) else: raise TypeError(f"{type(a)} not supported by assertion comparison") diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 7bdb2b532d9..4e819eec0b5 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -5,7 +5,7 @@ import pytest from xarray.backends.api import open_datatree -from xarray.datatree_.datatree.testing import assert_equal +from xarray.testing import assert_equal from xarray.tests import ( requires_h5netcdf, requires_netCDF4, diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 3de2fa62dc6..e667c8670c7 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -4,10 +4,9 @@ import pytest import xarray as xr -import xarray.datatree_.datatree.testing as dtt -import xarray.testing as xrt from xarray.core.datatree import DataTree from xarray.core.treenode import NotFoundInTreeError +from xarray.testing import assert_equal, assert_identical from xarray.tests import create_test_data, source_ndarray @@ -17,7 +16,7 @@ def test_empty(self): assert dt.name == "root" assert dt.parent is None assert dt.children == {} - xrt.assert_identical(dt.to_dataset(), xr.Dataset()) + assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): dt: DataTree = DataTree() @@ -115,7 +114,7 @@ def test_create_with_data(self): dat = xr.Dataset({"a": 0}) john: DataTree = DataTree(name="john", data=dat) - xrt.assert_identical(john.to_dataset(), dat) + assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): DataTree(name="mary", parent=john, data="junk") # type: ignore[arg-type] @@ -125,7 +124,7 @@ def test_set_data(self): dat = xr.Dataset({"a": 0}) john.ds = dat # type: ignore[assignment] - xrt.assert_identical(john.to_dataset(), dat) + assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): john.ds = "junk" # type: ignore[assignment] @@ -185,14 +184,14 @@ def test_getitem_self(self): def test_getitem_single_data_variable(self): data = xr.Dataset({"temp": [0, 50]}) results: DataTree = DataTree(name="results", data=data) - xrt.assert_identical(results["temp"], data["temp"]) + assert_identical(results["temp"], data["temp"]) def test_getitem_single_data_variable_from_node(self): data = xr.Dataset({"temp": [0, 50]}) folder1: DataTree = DataTree(name="folder1") results: DataTree = DataTree(name="results", parent=folder1) DataTree(name="highres", parent=results, data=data) - xrt.assert_identical(folder1["results/highres/temp"], data["temp"]) + assert_identical(folder1["results/highres/temp"], data["temp"]) def test_getitem_nonexistent_node(self): folder1: DataTree = DataTree(name="folder1") @@ -210,7 +209,7 @@ def test_getitem_nonexistent_variable(self): def test_getitem_multiple_data_variables(self): data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) results: DataTree = DataTree(name="results", data=data) - xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] + assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] @pytest.mark.xfail( reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)" @@ -218,7 +217,7 @@ def test_getitem_multiple_data_variables(self): def test_getitem_dict_like_selection_access_to_dataset(self): data = xr.Dataset({"temp": [0, 50]}) results: DataTree = DataTree(name="results", data=data) - xrt.assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] + assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] class TestUpdate: @@ -231,14 +230,14 @@ def test_update(self): print(dt._children) print(dt["a"]) print(expected) - dtt.assert_equal(dt, expected) + assert_equal(dt, expected) def test_update_new_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) folder1: DataTree = DataTree(name="folder1") folder1.update({"results": da}) expected = da.rename("results") - xrt.assert_equal(folder1["results"], expected) + assert_equal(folder1["results"], expected) def test_update_doesnt_alter_child_name(self): dt: DataTree = DataTree() @@ -256,7 +255,7 @@ def test_update_overwrite(self): print(actual) print(expected) - dtt.assert_equal(actual, expected) + assert_equal(actual, expected) class TestCopy: @@ -267,7 +266,7 @@ def test_copy(self, create_test_datatree): node.attrs["Test"] = [1, 2, 3] for copied in [dt.copy(deep=False), copy(dt)]: - dtt.assert_identical(dt, copied) + assert_identical(dt, copied) for node, copied_node in zip(dt.root.subtree, copied.root.subtree): assert node.encoding == copied_node.encoding @@ -291,7 +290,7 @@ def test_copy_subtree(self): actual = dt["/level1/level2"].copy() expected = DataTree.from_dict({"/level3": xr.Dataset()}, name="level2") - dtt.assert_identical(actual, expected) + assert_identical(actual, expected) def test_deepcopy(self, create_test_datatree): dt = create_test_datatree() @@ -300,7 +299,7 @@ def test_deepcopy(self, create_test_datatree): node.attrs["Test"] = [1, 2, 3] for copied in [dt.copy(deep=True), deepcopy(dt)]: - dtt.assert_identical(dt, copied) + assert_identical(dt, copied) for node, copied_node in zip(dt.root.subtree, copied.root.subtree): assert node.encoding == copied_node.encoding @@ -331,7 +330,7 @@ def test_copy_with_data(self, create_test_datatree): expected = orig.copy() for k, v in new_data.items(): expected[k].data = v - dtt.assert_identical(expected, actual) + assert_identical(expected, actual) # TODO test parents and children? @@ -372,13 +371,13 @@ def test_setitem_new_empty_node(self): john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) - xrt.assert_identical(mary.to_dataset(), xr.Dataset()) + assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): john: DataTree = DataTree(name="john") mary: DataTree = DataTree(name="mary", parent=john, data=xr.Dataset()) john["mary"] = DataTree() - xrt.assert_identical(mary.to_dataset(), xr.Dataset()) + assert_identical(mary.to_dataset(), xr.Dataset()) john.ds = xr.Dataset() # type: ignore[assignment] with pytest.raises(ValueError, match="has no name"): @@ -389,45 +388,45 @@ def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) results: DataTree = DataTree(name="results") results["."] = data - xrt.assert_identical(results.to_dataset(), data) + assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) folder1: DataTree = DataTree(name="folder1") folder1["results"] = data - xrt.assert_identical(folder1["results"].to_dataset(), data) + assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) folder1: DataTree = DataTree(name="folder1") folder1["results/highres"] = data - xrt.assert_identical(folder1["results/highres"].to_dataset(), data) + assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) folder1: DataTree = DataTree(name="folder1") folder1["results"] = da expected = da.rename("results") - xrt.assert_equal(folder1["results"], expected) + assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): data = xr.DataArray([0, 50]) folder1: DataTree = DataTree(name="folder1") folder1["results"] = data - xrt.assert_equal(folder1["results"], data) + assert_equal(folder1["results"], data) def test_setitem_variable(self): var = xr.Variable(data=[0, 50], dims="x") folder1: DataTree = DataTree(name="folder1") folder1["results"] = var - xrt.assert_equal(folder1["results"], xr.DataArray(var)) + assert_equal(folder1["results"], xr.DataArray(var)) def test_setitem_coerce_to_dataarray(self): folder1: DataTree = DataTree(name="folder1") folder1["results"] = 0 - xrt.assert_equal(folder1["results"], xr.DataArray(0)) + assert_equal(folder1["results"], xr.DataArray(0)) def test_setitem_add_new_variable_to_empty_node(self): results: DataTree = DataTree(name="results") @@ -449,7 +448,7 @@ def test_setitem_dataarray_replace_existing_node(self): p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) - xrt.assert_identical(results.to_dataset(), expected) + assert_identical(results.to_dataset(), expected) class TestDictionaryInterface: ... @@ -462,16 +461,16 @@ def test_data_in_root(self): assert dt.name is None assert dt.parent is None assert dt.children == {} - xrt.assert_identical(dt.to_dataset(), dat) + assert_identical(dt.to_dataset(), dat) def test_one_layer(self): dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) - xrt.assert_identical(dt.to_dataset(), xr.Dataset()) + assert_identical(dt.to_dataset(), xr.Dataset()) assert dt.name is None - xrt.assert_identical(dt["run1"].to_dataset(), dat1) + assert_identical(dt["run1"].to_dataset(), dat1) assert dt["run1"].children == {} - xrt.assert_identical(dt["run2"].to_dataset(), dat2) + assert_identical(dt["run2"].to_dataset(), dat2) assert dt["run2"].children == {} def test_two_layers(self): @@ -480,13 +479,13 @@ def test_two_layers(self): assert "highres" in dt.children assert "lowres" in dt.children highres_run = dt["highres/run"] - xrt.assert_identical(highres_run.to_dataset(), dat1) + assert_identical(highres_run.to_dataset(), dat1) def test_nones(self): dt = DataTree.from_dict({"d": None, "d/e": None}) assert [node.name for node in dt.subtree] == [None, "d", "e"] assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] - xrt.assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) + assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) def test_full(self, simple_datatree): dt = simple_datatree @@ -508,7 +507,7 @@ def test_datatree_values(self): actual = DataTree.from_dict({"a": dat1}) - dtt.assert_identical(actual, expected) + assert_identical(actual, expected) def test_roundtrip(self, simple_datatree): dt = simple_datatree @@ -587,16 +586,16 @@ def test_attribute_access(self, create_test_datatree): # vars / coords for key in ["a", "set0"]: - xrt.assert_equal(dt[key], getattr(dt, key)) + assert_equal(dt[key], getattr(dt, key)) assert key in dir(dt) # dims - xrt.assert_equal(dt["a"]["y"], getattr(dt.a, "y")) + assert_equal(dt["a"]["y"], getattr(dt.a, "y")) assert "y" in dir(dt["a"]) # children for key in ["set1", "set2", "set3"]: - dtt.assert_equal(dt[key], getattr(dt, key)) + assert_equal(dt[key], getattr(dt, key)) assert key in dir(dt) # attrs @@ -649,11 +648,11 @@ def test_assign(self): # kwargs form result = dt.assign(foo=xr.DataArray(0), a=DataTree()) - dtt.assert_equal(result, expected) + assert_equal(result, expected) # dict form result = dt.assign({"foo": xr.DataArray(0), "a": DataTree()}) - dtt.assert_equal(result, expected) + assert_equal(result, expected) class TestPipe: @@ -691,7 +690,7 @@ def f(x, tree, y): class TestSubset: def test_match(self): # TODO is this example going to cause problems with case sensitivity? - dt = DataTree.from_dict( + dt: DataTree = DataTree.from_dict( { "/a/A": None, "/a/B": None, @@ -706,10 +705,10 @@ def test_match(self): "/b/B": None, } ) - dtt.assert_identical(result, expected) + assert_identical(result, expected) def test_filter(self): - simpsons = DataTree.from_dict( + simpsons: DataTree = DataTree.from_dict( d={ "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), @@ -729,4 +728,99 @@ def test_filter(self): name="Abe", ) elders = simpsons.filter(lambda node: node["age"].item() > 18) - dtt.assert_identical(elders, expected) + assert_identical(elders, expected) + + +class TestDSMethodInheritance: + def test_dataset_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.isel(x=1)) + DataTree(name="results", parent=expected, data=ds.isel(x=1)) + + result = dt.isel(x=1) + assert_equal(result, expected) + + def test_reduce_method(self): + ds = xr.Dataset({"a": ("x", [False, True, False])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.any()) + DataTree(name="results", parent=expected, data=ds.any()) + + result = dt.any() + assert_equal(result, expected) + + def test_nan_reduce_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.mean()) + DataTree(name="results", parent=expected, data=ds.mean()) + + result = dt.mean() + assert_equal(result, expected) + + def test_cum_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.cumsum()) + DataTree(name="results", parent=expected, data=ds.cumsum()) + + result = dt.cumsum() + assert_equal(result, expected) + + +class TestOps: + def test_binary_op_on_int(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected: DataTree = DataTree(data=ds1 * 5) + DataTree(name="subnode", data=ds2 * 5, parent=expected) + + # TODO: Remove ignore when ops.py is migrated? + result: DataTree = dt * 5 # type: ignore[assignment,operator] + assert_equal(result, expected) + + def test_binary_op_on_dataset(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) + + expected: DataTree = DataTree(data=ds1 * other_ds) + DataTree(name="subnode", data=ds2 * other_ds, parent=expected) + + result = dt * other_ds + assert_equal(result, expected) + + def test_binary_op_on_datatree(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected: DataTree = DataTree(data=ds1 * ds1) + DataTree(name="subnode", data=ds2 * ds2, parent=expected) + + # TODO: Remove ignore when ops.py is migrated? + result: DataTree = dt * dt # type: ignore[operator] + assert_equal(result, expected) + + +class TestUFuncs: + def test_tree(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: np.sin(ds)) + result_tree = np.sin(dt) + assert_equal(result_tree, expected) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 16ca726759d..b8b55613c4a 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -8,7 +8,7 @@ check_isomorphic, map_over_subtree, ) -from xarray.datatree_.datatree.testing import assert_equal +from xarray.testing import assert_equal empty = xr.Dataset() diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 7e1802246c7..7cfffd68620 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -5,9 +5,14 @@ import pytest import xarray as xr + +# TODO: Remove imports in favour of xr.DataTree etc, once part of public API +from xarray.core.datatree import DataTree +from xarray.core.extensions import register_datatree_accessor from xarray.tests import assert_identical +@register_datatree_accessor("example_accessor") @xr.register_dataset_accessor("example_accessor") @xr.register_dataarray_accessor("example_accessor") class ExampleAccessor: @@ -19,6 +24,7 @@ def __init__(self, xarray_obj): class TestAccessor: def test_register(self) -> None: + @register_datatree_accessor("demo") @xr.register_dataset_accessor("demo") @xr.register_dataarray_accessor("demo") class DemoAccessor: @@ -31,6 +37,9 @@ def __init__(self, xarray_obj): def foo(self): return "bar" + dt: DataTree = DataTree() + assert dt.demo.foo == "bar" + ds = xr.Dataset() assert ds.demo.foo == "bar" diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 6923d26b79a..256a02d49e2 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -10,6 +10,7 @@ import xarray as xr from xarray.core import formatting +from xarray.core.datatree import DataTree # TODO: Remove when can do xr.DataTree from xarray.tests import requires_cftime, requires_dask, requires_netCDF4 ON_WINDOWS = sys.platform == "win32" @@ -555,6 +556,108 @@ def test_array_scalar_format(self) -> None: format(var, ".2f") assert "Using format_spec is only supported" in str(excinfo.value) + def test_datatree_print_empty_node(self): + dt: DataTree = DataTree(name="root") + printout = dt.__str__() + assert printout == "DataTree('root', parent=None)" + + def test_datatree_print_empty_node_with_attrs(self): + dat = xr.Dataset(attrs={"note": "has attrs"}) + dt: DataTree = DataTree(name="root", data=dat) + printout = dt.__str__() + assert printout == dedent( + """\ + DataTree('root', parent=None) + Dimensions: () + Data variables: + *empty* + Attributes: + note: has attrs""" + ) + + def test_datatree_print_node_with_data(self): + dat = xr.Dataset({"a": [0, 2]}) + dt: DataTree = DataTree(name="root", data=dat) + printout = dt.__str__() + expected = [ + "DataTree('root', parent=None)", + "Dimensions", + "Coordinates", + "a", + "Data variables", + "*empty*", + ] + for expected_line, printed_line in zip(expected, printout.splitlines()): + assert expected_line in printed_line + + def test_datatree_printout_nested_node(self): + dat = xr.Dataset({"a": [0, 2]}) + root: DataTree = DataTree(name="root") + DataTree(name="results", data=dat, parent=root) + printout = root.__str__() + assert printout.splitlines()[2].startswith(" ") + + def test_datatree_repr_of_node_with_data(self): + dat = xr.Dataset({"a": [0, 2]}) + dt: DataTree = DataTree(name="root", data=dat) + assert "Coordinates" in repr(dt) + + def test_diff_datatree_repr_structure(self): + dt_1: DataTree = DataTree.from_dict({"a": None, "a/b": None, "a/c": None}) + dt_2: DataTree = DataTree.from_dict({"d": None, "d/e": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not isomorphic + + Number of children on node '/a' of the left object: 2 + Number of children on node '/d' of the right object: 1""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "isomorphic") + assert actual == expected + + def test_diff_datatree_repr_node_names(self): + dt_1: DataTree = DataTree.from_dict({"a": None}) + dt_2: DataTree = DataTree.from_dict({"b": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not identical + + Node '/a' in the left object has name 'a' + Node '/b' in the right object has name 'b'""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "identical") + assert actual == expected + + def test_diff_datatree_repr_node_data(self): + # casting to int64 explicitly ensures that int64s are created on all architectures + ds1 = xr.Dataset({"u": np.int64(0), "v": np.int64(1)}) + ds3 = xr.Dataset({"w": np.int64(5)}) + dt_1: DataTree = DataTree.from_dict({"a": ds1, "a/b": ds3}) + ds2 = xr.Dataset({"u": np.int64(0)}) + ds4 = xr.Dataset({"w": np.int64(6)}) + dt_2: DataTree = DataTree.from_dict({"a": ds2, "a/b": ds4}) + + expected = dedent( + """\ + Left and right DataTree objects are not equal + + + Data in nodes at position '/a' do not match: + + Data variables only on the left object: + v int64 8B 1 + + Data in nodes at position '/a/b' do not match: + + Differing data variables: + L w int64 8B 5 + R w int64 8B 6""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "equals") + assert actual == expected + def test_inline_variable_array_repr_custom_repr() -> None: class CustomArray: