Skip to content

Commit

Permalink
Migrate treenode module. (#8757)
Browse files Browse the repository at this point in the history
* Update the formating tests

PR (#8702) added nbytes representation in DataArrays and Dataset repr, this adds it to the datatree tests.

* Migrate treenode module

Moves treenode.py and test_treenode.py.
Updates some typing.
Updates imports from treenode.

* Update NotFoundInTreeError description.

* Reformat some comments

Add test tree structure for easier understanding.

* Updates whats-new.rst

* mypy typing. (terrible?)

There must be a better way, but I don't know it.
particularly the list comprehension casts.

* Adds __repr__ to NamedNode and updates test

This test was broken becuase only the root node was being tested and none of
the previous nodes were represented in the __str__.

* Adds quotes to NamedNode __str__ representation.

* swaps " for ' in NamedNode __str__ representation.

* Adding Tom in so he gets blamed properly.

* resolve conflict whats-new.rst

Question is I did update below the released line to give Tom some credit.  I
hope that's is allowable.

* Moves test_treenode.py to xarray/tests.

Integrated tests.

* refactors backend tests for datatree IO

* Add explicit engine back in test_to_zarr

* Removes OrderedDict from treenode

* Renames tests/test_io.py -> tests/test_backends_datatree.py

* typo

* Add types

* Pass mypy for 3.9
  • Loading branch information
flamingbear authored Feb 27, 2024
1 parent e47eb92 commit dfdd631
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 240 deletions.
10 changes: 7 additions & 3 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~

- Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`)
By `Matt Savoie <https://github.com/flamingbear>`_ and `Tom Nicholas
<https://github.com/TomNicholas>`_.


.. _whats-new.2024.02.0:
Expand Down Expand Up @@ -145,9 +147,11 @@ Internal Changes
``xarray/namedarray``. (:pull:`8319`)
By `Tom Nicholas <https://github.com/TomNicholas>`_ and `Anderson Banihirwe <https://github.com/andersy005>`_.
- Imports ``datatree`` repository and history into internal location. (:pull:`8688`)
By `Matt Savoie <https://github.com/flamingbear>`_ and `Justus Magin <https://github.com/keewis>`_.
By `Matt Savoie <https://github.com/flamingbear>`_, `Justus Magin <https://github.com/keewis>`_
and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Adds :py:func:`open_datatree` into ``xarray/backends`` (:pull:`8697`)
By `Matt Savoie <https://github.com/flamingbear>`_.
By `Matt Savoie <https://github.com/flamingbear>`_ and `Tom Nicholas
<https://github.com/TomNicholas>`_.
- Refactor :py:meth:`xarray.core.indexing.DaskIndexingAdapter.__getitem__` to remove an unnecessary
rewrite of the indexer key (:issue: `8377`, :pull:`8758`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def _open_datatree_netcdf(
**kwargs,
) -> DataTree:
from xarray.backends.api import open_dataset
from xarray.core.treenode import NodePath
from xarray.datatree_.datatree import DataTree
from xarray.datatree_.datatree.treenode import NodePath

ds = open_dataset(filename_or_obj, **kwargs)
tree_root = DataTree.from_dict({"/": ds})
Expand All @@ -159,7 +159,7 @@ def _open_datatree_netcdf(


def _iter_nc_groups(root, parent="/"):
from xarray.datatree_.datatree.treenode import NodePath
from xarray.core.treenode import NodePath

parent = NodePath(parent)
for path, group in root.groups.items():
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,8 +1048,8 @@ def open_datatree(
import zarr

from xarray.backends.api import open_dataset
from xarray.core.treenode import NodePath
from xarray.datatree_.datatree import DataTree
from xarray.datatree_.datatree.treenode import NodePath

zds = zarr.open_group(filename_or_obj, mode="r")
ds = open_dataset(filename_or_obj, engine="zarr", **kwargs)
Expand All @@ -1075,7 +1075,7 @@ def open_datatree(


def _iter_zarr_groups(root, parent="/"):
from xarray.datatree_.datatree.treenode import NodePath
from xarray.core.treenode import NodePath

parent = NodePath(parent)
for path, group in root.groups():
Expand Down
100 changes: 50 additions & 50 deletions xarray/datatree_/datatree/treenode.py → xarray/core/treenode.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from __future__ import annotations

import sys
from collections import OrderedDict
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
from typing import (
TYPE_CHECKING,
Generic,
Iterator,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

from xarray.core.utils import Frozen, is_dict_like
Expand All @@ -25,7 +20,7 @@ class InvalidTreeError(Exception):


class NotFoundInTreeError(ValueError):
"""Raised when operation can't be completed because one node is part of the expected tree."""
"""Raised when operation can't be completed because one node is not part of the expected tree."""


class NodePath(PurePosixPath):
Expand Down Expand Up @@ -55,8 +50,8 @@ class TreeNode(Generic[Tree]):
This class stores no data, it has only parents and children attributes, and various methods.
Stores child nodes in an Ordered Dictionary, which is necessary to ensure that equality checks between two trees
also check that the order of child nodes is the same.
Stores child nodes in an dict, ensuring that equality checks between trees
and order of child nodes is preserved (since python 3.7).
Nodes themselves are intrinsically unnamed (do not possess a ._name attribute), but if the node has a parent you can
find the key it is stored under via the .name property.
Expand All @@ -73,15 +68,16 @@ class TreeNode(Generic[Tree]):
Also allows access to any other node in the tree via unix-like paths, including upwards referencing via '../'.
(This class is heavily inspired by the anytree library's NodeMixin class.)
"""

_parent: Optional[Tree]
_children: OrderedDict[str, Tree]
_parent: Tree | None
_children: dict[str, Tree]

def __init__(self, children: Optional[Mapping[str, Tree]] = None):
def __init__(self, children: Mapping[str, Tree] | None = None):
"""Create a parentless node."""
self._parent = None
self._children = OrderedDict()
self._children = {}
if children is not None:
self.children = children

Expand All @@ -91,7 +87,7 @@ def parent(self) -> Tree | None:
return self._parent

def _set_parent(
self, new_parent: Tree | None, child_name: Optional[str] = None
self, new_parent: Tree | None, child_name: str | None = None
) -> None:
# TODO is it possible to refactor in a way that removes this private method?

Expand Down Expand Up @@ -127,17 +123,15 @@ def _detach(self, parent: Tree | None) -> None:
if parent is not None:
self._pre_detach(parent)
parents_children = parent.children
parent._children = OrderedDict(
{
name: child
for name, child in parents_children.items()
if child is not self
}
)
parent._children = {
name: child
for name, child in parents_children.items()
if child is not self
}
self._parent = None
self._post_detach(parent)

def _attach(self, parent: Tree | None, child_name: Optional[str] = None) -> None:
def _attach(self, parent: Tree | None, child_name: str | None = None) -> None:
if parent is not None:
if child_name is None:
raise ValueError(
Expand Down Expand Up @@ -167,7 +161,7 @@ def children(self: Tree) -> Mapping[str, Tree]:
@children.setter
def children(self: Tree, children: Mapping[str, Tree]) -> None:
self._check_children(children)
children = OrderedDict(children)
children = {**children}

old_children = self.children
del self.children
Expand Down Expand Up @@ -242,7 +236,7 @@ def _iter_parents(self: Tree) -> Iterator[Tree]:
yield node
node = node.parent

def iter_lineage(self: Tree) -> Tuple[Tree, ...]:
def iter_lineage(self: Tree) -> tuple[Tree, ...]:
"""Iterate up the tree, starting from the current node."""
from warnings import warn

Expand All @@ -254,7 +248,7 @@ def iter_lineage(self: Tree) -> Tuple[Tree, ...]:
return tuple((self, *self.parents))

@property
def lineage(self: Tree) -> Tuple[Tree, ...]:
def lineage(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the closest."""
from warnings import warn

Expand All @@ -266,12 +260,12 @@ def lineage(self: Tree) -> Tuple[Tree, ...]:
return self.iter_lineage()

@property
def parents(self: Tree) -> Tuple[Tree, ...]:
def parents(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the closest."""
return tuple(self._iter_parents())

@property
def ancestors(self: Tree) -> Tuple[Tree, ...]:
def ancestors(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the most distant."""

from warnings import warn
Expand Down Expand Up @@ -306,7 +300,7 @@ def is_leaf(self) -> bool:
return self.children == {}

@property
def leaves(self: Tree) -> Tuple[Tree, ...]:
def leaves(self: Tree) -> tuple[Tree, ...]:
"""
All leaf nodes.
Expand All @@ -315,20 +309,18 @@ def leaves(self: Tree) -> Tuple[Tree, ...]:
return tuple([node for node in self.subtree if node.is_leaf])

@property
def siblings(self: Tree) -> OrderedDict[str, Tree]:
def siblings(self: Tree) -> dict[str, Tree]:
"""
Nodes with the same parent as this node.
"""
if self.parent:
return OrderedDict(
{
name: child
for name, child in self.parent.children.items()
if child is not self
}
)
return {
name: child
for name, child in self.parent.children.items()
if child is not self
}
else:
return OrderedDict()
return {}

@property
def subtree(self: Tree) -> Iterator[Tree]:
Expand All @@ -341,12 +333,12 @@ def subtree(self: Tree) -> Iterator[Tree]:
--------
DataTree.descendants
"""
from . import iterators
from xarray.datatree_.datatree import iterators

return iterators.PreOrderIter(self)

@property
def descendants(self: Tree) -> Tuple[Tree, ...]:
def descendants(self: Tree) -> tuple[Tree, ...]:
"""
Child nodes and all their child nodes.
Expand Down Expand Up @@ -431,7 +423,7 @@ def _post_attach(self: Tree, parent: Tree) -> None:
"""Method call after attaching to `parent`."""
pass

def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:
def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None:
"""
Return the child node with the specified key.
Expand All @@ -445,7 +437,7 @@ def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:

# TODO `._walk` method to be called by both `_get_item` and `_set_item`

def _get_item(self: Tree, path: str | NodePath) -> Union[Tree, T_DataArray]:
def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray:
"""
Returns the object lying at the given path.
Expand Down Expand Up @@ -488,24 +480,26 @@ def _set(self: Tree, key: str, val: Tree) -> None:
def _set_item(
self: Tree,
path: str | NodePath,
item: Union[Tree, T_DataArray],
item: Tree | T_DataArray,
new_nodes_along_path: bool = False,
allow_overwrite: bool = True,
) -> None:
"""
Set a new item in the tree, overwriting anything already present at that path.
The given value either forms a new node of the tree or overwrites an existing item at that location.
The given value either forms a new node of the tree or overwrites an
existing item at that location.
Parameters
----------
path
item
new_nodes_along_path : bool
If true, then if necessary new nodes will be created along the given path, until the tree can reach the
specified location.
If true, then if necessary new nodes will be created along the
given path, until the tree can reach the specified location.
allow_overwrite : bool
Whether or not to overwrite any existing node at the location given by path.
Whether or not to overwrite any existing node at the location given
by path.
Raises
------
Expand Down Expand Up @@ -580,9 +574,9 @@ class NamedNode(TreeNode, Generic[Tree]):
Implements path-like relationships to other nodes in its tree.
"""

_name: Optional[str]
_parent: Optional[Tree]
_children: OrderedDict[str, Tree]
_name: str | None
_parent: Tree | None
_children: dict[str, Tree]

def __init__(self, name=None, children=None):
super().__init__(children=children)
Expand All @@ -603,8 +597,14 @@ def name(self, name: str | None) -> None:
raise ValueError("node names cannot contain forward slashes")
self._name = name

def __repr__(self, level=0):
repr_value = "\t" * level + self.__str__() + "\n"
for child in self.children:
repr_value += self.get(child).__repr__(level + 1)
return repr_value

def __str__(self) -> str:
return f"NamedNode({self.name})" if self.name else "NamedNode()"
return f"NamedNode('{self.name}')" if self.name else "NamedNode()"

def _post_attach(self: NamedNode, parent: NamedNode) -> None:
"""Ensures child has name attribute corresponding to key under which it has been stored."""
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .datatree import DataTree
from .extensions import register_datatree_accessor
from .mapping import TreeIsomorphismError, map_over_subtree
from .treenode import InvalidTreeError, NotFoundInTreeError
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError


__all__ = (
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
MappedDataWithCoords,
)
from .render import RenderTree
from .treenode import NamedNode, NodePath, Tree
from xarray.core.treenode import NamedNode, NodePath, Tree

try:
from xarray.core.variable import calculate_dimensions
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import abc
from typing import Callable, Iterator, List, Optional

from .treenode import Tree
from xarray.core.treenode import Tree

"""These iterators are copied from anytree.iterators, with minor modifications."""

Expand Down
4 changes: 2 additions & 2 deletions xarray/datatree_/datatree/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from xarray import DataArray, Dataset

from .iterators import LevelOrderIter
from .treenode import NodePath, TreeNode
from xarray.core.treenode import NodePath, TreeNode

if TYPE_CHECKING:
from .datatree import DataTree
from xarray.core.datatree import DataTree


class TreeIsomorphismError(ValueError):
Expand Down
6 changes: 3 additions & 3 deletions xarray/datatree_/datatree/tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ def test_diff_node_data(self):
Data in nodes at position '/a' do not match:
Data variables only on the left object:
v int64 1
v int64 8B 1
Data in nodes at position '/a/b' do not match:
Differing data variables:
L w int64 5
R w int64 6"""
L w int64 8B 5
R w int64 8B 6"""
)
actual = diff_tree_repr(dt_1, dt_2, "equals")
assert actual == expected
Loading

0 comments on commit dfdd631

Please sign in to comment.