Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Method to reorder nodes #271

Closed
wants to merge 13 commits into from
117 changes: 116 additions & 1 deletion datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import itertools
import re
from collections import OrderedDict
from html import escape
from typing import (
Expand Down Expand Up @@ -79,6 +80,74 @@
T_Path = Union[str, NodePath]


_SYMBOLIC_NODE_NAME = r"\w+"
_SYMBOLIC_NODEPATH = f"\/?{_SYMBOLIC_NODE_NAME}(\/{_SYMBOLIC_NODE_NAME})*\/?"
_SYMBOLIC_REORDERING = f"^{_SYMBOLIC_NODEPATH}->{_SYMBOLIC_NODEPATH}$"


def _parse_symbolic_ordering(ordering: str) -> Tuple[List[str], List[str]]:
"""Parse a symbolic reordering string of the form 'a/b -> b/a'."""
if not re.match(_SYMBOLIC_REORDERING, ordering):
raise ValueError(f"Invalid symbolic reordering: {ordering}")

in_txt, out_txt = ordering.split("->")
old_symbolic_order = re.findall(_SYMBOLIC_NODE_NAME, in_txt)
new_symbolic_order = re.findall(_SYMBOLIC_NODE_NAME, out_txt)

# Check number of symbols is the same on both sides
if len(old_symbolic_order) != len(new_symbolic_order):
raise ValueError(
"Invalid symbolic reordering. The depth of the symbolic path on each side must be equal, "
f"but the left has {len(old_symbolic_order)} parts and the right has {len(new_symbolic_order)}"
f" parts."
)

# Check every symbol appears only once
if len(set(old_symbolic_order)) < len(old_symbolic_order):
# TODO
repeated_symbols = ...
raise ValueError(
"Invalid symbolic reordering. Each symbol must appear only once on each side, "
f"but the symbols {repeated_symbols} appear more than once in the left-hand side."
)
if len(set(new_symbolic_order)) < len(new_symbolic_order):
# TODO
repeated_symbols = ...
raise ValueError(
"Invalid symbolic reordering. Each symbol must appear only once on each side, "
f"but the symbols {repeated_symbols} appear more than once in the right-hand side."
)

# Check every symbol appears on both sides
all_symbols = set(old_symbolic_order).union(set(new_symbolic_order))
if len(set(old_symbolic_order)) < len(all_symbols):
unmatched_symbols = all_symbols - set(old_symbolic_order)
raise ValueError(
"Invalid symbolic reordering. Every symbol must be present on both sides, but"
f"the symbols {unmatched_symbols} are only present on the right-hand side."
)
if len(set(new_symbolic_order)) < len(all_symbols):
unmatched_symbols = all_symbols - set(new_symbolic_order)
raise ValueError(
"Invalid symbolic reordering. Every symbol must be present on both sides, but"
f"the symbols {unmatched_symbols} are only present on the left-hand side."
)

return old_symbolic_order, new_symbolic_order


def _reorder_path(path: str, old_order: List[str], new_order: List[str]) -> str:
    """
    Re-order the leading parts of the given path from old_order to match new_order.

    The i-th symbol of ``old_order`` names the i-th part of ``path``. The result
    places, at each position of ``new_order``, the part named by the symbol found
    there. Parts deeper than the number of symbols are kept, unchanged and in
    their original order, after the re-ordered ones, so the depth of the path is
    preserved.

    Parameters
    ----------
    path : str
        Path whose parts are to be re-ordered.
    old_order : list of str
        Symbols naming the leading parts of ``path``, in their current order.
    new_order : list of str
        The same symbols, in the desired order.

    Returns
    -------
    str
        The re-ordered path.

    Raises
    ------
    ValueError
        If ``path`` has fewer parts than there are symbols in ``old_order``.
    """
    # NOTE(review): presumably `path` is a relative path, as passed by `reorder`;
    # a leading "/" may count as a part of its own - confirm against NodePath.
    parts = NodePath(path).parts
    n_symbols = len(old_order)
    if n_symbols > len(parts):
        raise ValueError(f"node {path} only has depth {len(parts)}")

    # For each slot of the new order, look up where its symbol sat in the old
    # order and take the part from there. (Indexing the other way round applies
    # the inverse permutation, which only coincides for swaps.)
    reordered_parts = [parts[old_order.index(symbol)] for symbol in new_order]

    # Parts beyond the symbolic pattern are not re-ordered, but must not be lost.
    trailing_parts = parts[n_symbols:]

    # NOTE(review): assumes NodePath takes its segments as separate positional
    # arguments, like pathlib.PurePosixPath - confirm.
    return str(NodePath(*reordered_parts, *trailing_parts))


def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset:
if isinstance(data, DataArray):
ds = data.to_dataset()
Expand Down Expand Up @@ -1302,6 +1371,52 @@ def match(self, pattern: str) -> DataTree:
}
return DataTree.from_dict(matching_nodes, name=self.root.name)

def reorder(self, ordering: str) -> DataTree:
    """
    Reorder levels of all nodes in this subtree by rearranging the parts of each of their paths.

    In general this operation will preserve the depth of every node (and hence depth of the whole subtree),
    but will not preserve the width at any level.

    Parameters
    ----------
    ordering: str
        String specifying symbolically how to reorder levels of each path, for example:
        'a/b/c -> b/c/a'

        Generally must be of the form:
        '{OLD_ORDER} -> {NEW_ORDER}'
        where OLD_ORDER = 'a/b/***/y/z', representing a symbolic ordering of the parts of the node path,
        and NEW_ORDER = 'z/a/***/b/y', representing an arbitrary re-ordering of the same number of parts.
        (Here the triple asterisk stands in for an arbitrary number of parts.)

        Symbols must be unique, and each symbol in the old order must have a corresponding entry in the new order,
        so the number of symbols must be the same in the new order as in the old order.

        Paths are re-ordered starting at the root of this subtree. Re-ordering from the leaves instead by
        prepending an ellipsis (e.g. '.../a/b -> .../b/a') is not implemented: the pattern that validates
        ``ordering`` only admits symbols made of word characters, so such a string is rejected.

    Returns
    -------
    reordered: DataTree
        New DataTree object built from the re-ordered paths of the nodes in this subtree.

    Raises
    ------
    ValueError
        If ``ordering`` is not a valid symbolic reordering, or if a node's path has fewer parts
        than there are symbols in ``ordering``.
    """
    old_symbolic_order, new_symbolic_order = _parse_symbolic_ordering(ordering)

    # only re-order the subtree, and return a new copy, to avoid messing up parents of this node
    # NOTE(review): `self.subtree` presumably includes this node itself, whose path relative to
    # itself would have no parts - confirm the depth check in `_reorder_path` does not reject it.
    reordered_dict = {
        _reorder_path(
            node.relative_to(self), old_symbolic_order, new_symbolic_order
        ): node.ds
        for node in self.subtree
    }

    return DataTree.from_dict(reordered_dict)

def map_over_subtree(
self,
func: Callable,
Expand Down Expand Up @@ -1477,7 +1592,7 @@ def to_netcdf(
Note that unlimited_dims may also be set via
``dataset.encoding["unlimited_dims"]``.
kwargs :
Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
"""
from .io import _datatree_to_netcdf

Expand Down
40 changes: 40 additions & 0 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,46 @@ def test_assign(self):
dtt.assert_equal(result, expected)


class TestReorder:
    """Tests for ``DataTree.reorder`` and its validation of symbolic orderings."""

    @pytest.mark.parametrize(
        "in_dict, order, expected",
        [
            ({"a": None}, "a -> a", {"a": None}),
            ({"a/b": None}, "a/b -> b/a", {"b/a": None}),
        ],
    )
    def test_reorder(self, in_dict, order, expected):
        dt = DataTree.from_dict(in_dict)
        # Use the parametrized ordering; a hard-coded one would make every
        # case exercise the same re-ordering regardless of its parameters.
        out_dict = dt.reorder(order).to_dict()
        # NOTE(review): `expected` maps paths to None - confirm this is what
        # `to_dict()` returns for nodes created without data.
        assert out_dict == expected

    def test_invalid_order(self):
        dt = DataTree.from_dict({"A/B/C": None})

        # no arrow at all
        with pytest.raises(ValueError, match="Invalid symbolic reordering"):
            dt.reorder("a")

        # nothing on the right-hand side
        with pytest.raises(ValueError, match="Invalid symbolic reordering"):
            dt.reorder("a->")

        # different number of symbols on each side
        with pytest.raises(
            ValueError, match="depth of the symbolic path on each side must be equal"
        ):
            dt.reorder("a->a/b")

        # symbols differ between the two sides
        with pytest.raises(ValueError, match="must be present on both sides"):
            dt.reorder("a->b")

        # a symbol repeated within one side
        with pytest.raises(ValueError, match="must appear only once"):
            dt.reorder("a/a/b->b/a/a")

    def test_not_deep_enough(self):
        dt = DataTree.from_dict({"A/B/C": None})

        # Match only the fixed part of the message; the node path and depth
        # are interpolated at runtime, so a literal "X"/"Y" can never match.
        with pytest.raises(ValueError, match="only has depth"):
            dt.reorder("a/b->b/a")


class TestPipe:
def test_noop(self, create_test_datatree):
dt = create_test_datatree()
Expand Down
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ For manipulating, traversing, navigating, or mapping over the tree structure.
DataTree.pipe
DataTree.match
DataTree.filter
DataTree.reorder

DataTree Contents
-----------------
Expand Down
3 changes: 3 additions & 0 deletions docs/source/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ v0.0.13 (unreleased)
New Features
~~~~~~~~~~~~

- New :py:meth:`DataTree.reorder` method for re-ordering levels of all nodes in the tree according to a
symbolic pattern such as ``a/b->b/a``. (:pull:`271`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- New :py:meth:`DataTree.match` method for glob-like pattern matching of node paths. (:pull:`267`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Indicate which node caused the problem if error encountered while applying user function using :py:func:`map_over_subtree`
Expand Down
Loading