Re-implement map_over_datasets
The main changes:

- It is implemented using zip_subtrees, which means it should properly
  handle DataTrees whose nodes are defined in a different order.
- For simplicity, I removed handling of `**kwargs`, in order to preserve
  some flexibility for adding keyword arguments.
- I removed automatic skipping of empty nodes, because there are almost
  assuredly cases where skipping would not be what the user wants. This
  could be restored with an optional keyword argument. (A usage sketch
  follows below.)
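Below is a minimal usage sketch of the resulting behaviour, assuming the `xr.DataTree.from_dict` and `DataTree.map_over_datasets` APIs as shown in this diff; the tree layout and variable names are made up for illustration.

```python
import numpy as np
import xarray as xr

# Hypothetical two-node tree; node names and data are illustrative only.
tree = xr.DataTree.from_dict(
    {
        "/": xr.Dataset({"a": ("x", np.arange(3))}),
        "/child": xr.Dataset({"b": ("x", np.arange(3) * 2.0)}),
    }
)

# func receives each node's dataset (empty nodes are no longer skipped) and
# should return a Dataset, None, or a tuple of Datasets.
doubled = tree.map_over_datasets(lambda ds: ds * 2)
print(doubled["child"].dataset)
```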
shoyer committed Oct 15, 2024
1 parent 4480e11 commit 739573a
Showing 4 changed files with 124 additions and 190 deletions.
5 changes: 1 addition & 4 deletions xarray/core/datatree.py
@@ -1387,7 +1387,6 @@ def map_over_datasets(
self,
func: Callable,
*args: Iterable[Any],
**kwargs: Any,
) -> DataTree | tuple[DataTree, ...]:
"""
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
@@ -1406,8 +1405,6 @@ def map_over_datasets(
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
**kwargs : Any
Keyword arguments passed on to `func`.
Returns
-------
@@ -1417,7 +1414,7 @@
# TODO this signature means that func has no way to know which node it is being called upon - change?

# TODO fix this typing error
return map_over_datasets(func)(self, *args, **kwargs)
return map_over_datasets(func)(self, *args)

def pipe(
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
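Since `**kwargs` are no longer forwarded, keyword arguments for `func` now have to be bound before mapping. A minimal sketch under that assumption, reusing the illustrative `tree` from above; `functools.partial` is just one way to do the binding:

```python
import functools

def shift(ds, offset=0):
    return ds + offset

# Extra positional arguments are still forwarded to func at every node...
shifted = tree.map_over_datasets(shift, 5)

# ...while keyword arguments must be bound up front.
also_shifted = tree.map_over_datasets(functools.partial(shift, offset=5))
```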
221 changes: 78 additions & 143 deletions xarray/core/datatree_mapping.py
@@ -2,14 +2,12 @@

import functools
import sys
from collections.abc import Callable
from itertools import repeat
from typing import TYPE_CHECKING
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, cast

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.formatting import diff_treestructure
from xarray.core.treenode import NodePath, TreeNode
from xarray.core.treenode import TreeNode, zip_subtrees

if TYPE_CHECKING:
from xarray.core.datatree import DataTree
@@ -125,110 +123,55 @@ def map_over_datasets(func: Callable) -> Callable:
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_datasets(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
def _map_over_datasets(*args) -> DataTree | tuple[DataTree, ...]:
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from xarray.core.datatree import DataTree

all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
a for a in kwargs.values() if isinstance(a, DataTree)
]

if len(all_tree_inputs) > 0:
first_tree, *other_trees = all_tree_inputs
else:
raise TypeError("Must pass at least one tree object")

for other_tree in other_trees:
# isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
check_isomorphic(
first_tree, other_tree, require_names_equal=False, check_from_root=False
)

# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
out_data_objects = {}
args_as_tree_length_iterables = [
a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
]
n_args = len(args_as_tree_length_iterables)
kwargs_as_tree_length_iterables = {
k: v.subtree if isinstance(v, DataTree) else repeat(v)
for k, v in kwargs.items()
}
for node_of_first_tree, *all_node_args in zip(
first_tree.subtree,
*args_as_tree_length_iterables,
*list(kwargs_as_tree_length_iterables.values()),
strict=False,
):
node_args_as_datasetviews = [
a.dataset if isinstance(a, DataTree) else a
for a in all_node_args[:n_args]
]
node_kwargs_as_datasetviews = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
v.dataset if isinstance(v, DataTree) else v
for v in all_node_args[n_args:]
],
strict=True,
)
out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {}

tree_args = [arg for arg in args if isinstance(arg, DataTree)]
subtrees = zip_subtrees(*tree_args)

for node_tree_args in subtrees:

node_dataset_args = [arg.dataset for arg in node_tree_args]
for i, arg in enumerate(args):
if not isinstance(arg, DataTree):
node_dataset_args.insert(i, arg)

path = (
"/"
if node_tree_args[0] is tree_args[0]
else node_tree_args[0].relative_to(tree_args[0])
)
func_with_error_context = _handle_errors_with_path_context(
node_of_first_tree.path
)(func)

if node_of_first_tree.has_data:
# call func on the data in this particular set of corresponding nodes
results = func_with_error_context(
*node_args_as_datasetviews, **node_kwargs_as_datasetviews
)
elif node_of_first_tree.has_attrs:
# propagate attrs
results = node_of_first_tree.dataset
else:
# nothing to propagate so use fastpath to create empty node in new tree
results = None
func_with_error_context = _handle_errors_with_path_context(path)(func)
results = func_with_error_context(*node_dataset_args)

# TODO implement mapping over multiple trees in-place using if conditions from here on?
out_data_objects[node_of_first_tree.path] = results
out_data_objects[path] = results

# Find out how many return values we received
num_return_values = _check_all_return_values(out_data_objects)

# Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
original_root_path = first_tree.path
result_trees = []
for i in range(num_return_values):
out_tree_contents = {}
for n in first_tree.subtree:
p = n.path
if p in out_data_objects.keys():
if isinstance(out_data_objects[p], tuple):
output_node_data = out_data_objects[p][i]
else:
output_node_data = out_data_objects[p]
else:
output_node_data = None

# Discard parentage so that new trees don't include parents of input nodes
relative_path = str(NodePath(p).relative_to(original_root_path))
relative_path = "/" if relative_path == "." else relative_path
out_tree_contents[relative_path] = output_node_data

new_tree = DataTree.from_dict(
out_tree_contents,
name=first_tree.name,
)
result_trees.append(new_tree)
if num_return_values is None:
out_data = cast(Mapping[str, Dataset | None], out_data_objects)
return DataTree.from_dict(out_data, name=tree_args[0].name)

# If only one result then don't wrap it in a tuple
if len(result_trees) == 1:
return result_trees[0]
else:
return tuple(result_trees)
out_data_tuples = cast(
Mapping[str, tuple[Dataset | None, ...]], out_data_objects
)
output_dicts: list[dict[str, Dataset | None]] = [
{} for _ in range(num_return_values)
]
for path, outputs in out_data_tuples.items():
for output_dict, output in zip(output_dicts, outputs, strict=False):
output_dict[path] = output

return tuple(
DataTree.from_dict(output_dict, name=tree_args[0].name)
for output_dict in output_dicts
)

return _map_over_datasets
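For illustration, a hedged sketch of how the loop above handles mixed positional arguments: `DataTree` arguments are zipped node-by-node with `zip_subtrees`, while non-tree arguments are re-inserted at their original positions before `func` is called (the trees and the weight value are hypothetical):

```python
import numpy as np
import xarray as xr

# Hypothetical isomorphic input trees; names and values are illustrative only.
tree_a = xr.DataTree.from_dict({"/x": xr.Dataset({"v": ("t", np.ones(4))})})
tree_b = xr.DataTree.from_dict({"/x": xr.Dataset({"v": ("t", np.arange(4.0))})})

def weighted_sum(ds_a, weight, ds_b):
    return ds_a * weight + ds_b

# Both trees are walked in lock-step (nodes matched by name), and the scalar
# weight is passed through unchanged to every call of func.
combined = tree_a.map_over_datasets(weighted_sum, 0.5, tree_b)
```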

@@ -260,62 +203,54 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)


def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None:
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, Dataset | DataArray):
return 1
elif isinstance(obj, tuple):
for r in obj:
if not isinstance(r, Dataset | DataArray):
raise TypeError(
f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
f"of type {type(r)}, not Dataset or DataArray."
)
return len(obj)
else:
if isinstance(obj, None | Dataset):
return None # no need to pack results

if not isinstance(obj, tuple) or not all(
isinstance(r, Dataset | None) for r in obj
):
raise TypeError(
f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
f"Dataset or DataArray, nor a tuple of such types."
f"the result of calling func on the node at position is not a Dataset or None "
f"or a tuple of such types: {obj!r}"
)

return len(obj)

def _check_all_return_values(returned_objects):
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""

if all(r is None for r in returned_objects.values()):
raise TypeError(
"Called supplied function on all nodes but found a return value of None for"
"all of them."
)
def _check_all_return_values(returned_objects) -> int | None:
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""

result_data_objects = [
(path_to_node, r)
for path_to_node, r in returned_objects.items()
if r is not None
(path_to_node, r) for path_to_node, r in returned_objects.items()
]

if len(result_data_objects) == 1:
# Only one node in the tree: no need to check consistency of results between nodes
path_to_node, result = result_data_objects[0]
num_return_values = _check_single_set_return_values(path_to_node, result)
else:
prev_path, _ = result_data_objects[0]
prev_num_return_values, num_return_values = None, None
for path_to_node, obj in result_data_objects[1:]:
num_return_values = _check_single_set_return_values(path_to_node, obj)

if (
num_return_values != prev_num_return_values
and prev_num_return_values is not None
):
first_path, result = result_data_objects[0]
return_values = _check_single_set_return_values(first_path, result)

for path_to_node, obj in result_data_objects[1:]:
cur_return_values = _check_single_set_return_values(path_to_node, obj)

if return_values != cur_return_values:
if return_values is None:
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
f"{prev_num_return_values} separate return values."
f"Calling func on the nodes at position {path_to_node} returns "
f"a tuple of {cur_return_values} datasets, whereas calling func on the "
f"nodes at position {first_path} instead returns a single dataset."
)
elif cur_return_values is None:
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns "
f"a single dataset, whereas calling func on the nodes at position "
f"{first_path} instead returns a tuple of {return_values} datasets."
)
else:
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns "
f"a tuple of {cur_return_values} datasets, whereas calling func on "
f"the nodes at position {first_path} instead returns a tuple of "
f"{return_values} datasets."
)

prev_path, prev_num_return_values = path_to_node, num_return_values

return num_return_values
return return_values
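The checks above distinguish a single `Dataset` (or `None`) result from a tuple of results. A hedged sketch of the tuple case, reusing the illustrative `tree` from earlier: when `func` returns the same number of `Dataset` objects at every node, the wrapper returns a matching tuple of `DataTree` objects.

```python
def double_and_triple(ds):
    # Returning a tuple at every node yields a tuple of result trees.
    return ds * 2, ds * 3

doubled, tripled = tree.map_over_datasets(double_and_triple)
```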
3 changes: 3 additions & 0 deletions xarray/core/treenode.py
@@ -791,6 +791,9 @@ def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]:
------
Tuples of matching subtrees.
"""
if not trees:
raise TypeError("Must pass at least one tree object")

# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
queue = collections.deque([trees])

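For reference, a small sketch of the `zip_subtrees` behaviour that `map_over_datasets` now relies on, plus the new guard added above; `zip_subtrees` is an internal helper, so the import path is an assumption, and the trees are made up for illustration:

```python
import xarray as xr
from xarray.core.treenode import zip_subtrees

first = xr.DataTree.from_dict({"/a": xr.Dataset(), "/b": xr.Dataset()})
second = xr.DataTree.from_dict({"/b": xr.Dataset(), "/a": xr.Dataset()})

# Nodes are paired by name, so the differing definition order does not matter.
for node_1, node_2 in zip_subtrees(first, second):
    print(node_1.path, node_2.path)

# The new guard: iterating zip_subtrees() with no arguments raises TypeError.
try:
    next(zip_subtrees())
except TypeError as err:
    print(err)  # Must pass at least one tree object
```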