diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e503b5c0741..a1351364a3e 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1387,7 +1387,6 @@ def map_over_datasets( self, func: Callable, *args: Iterable[Any], - **kwargs: Any, ) -> DataTree | tuple[DataTree, ...]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. @@ -1406,8 +1405,6 @@ def map_over_datasets( Function will not be applied to any nodes without datasets. *args : tuple, optional Positional arguments passed on to `func`. - **kwargs : Any - Keyword arguments passed on to `func`. Returns ------- @@ -1417,7 +1414,7 @@ def map_over_datasets( # TODO this signature means that func has no way to know which node it is being called upon - change? # TODO fix this typing error - return map_over_datasets(func)(self, *args, **kwargs) + return map_over_datasets(func)(self, *args) def pipe( self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 2817effa856..f11e5e7bb0b 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -2,14 +2,12 @@ import functools import sys -from collections.abc import Callable -from itertools import repeat -from typing import TYPE_CHECKING +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, cast -from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.formatting import diff_treestructure -from xarray.core.treenode import NodePath, TreeNode +from xarray.core.treenode import TreeNode, zip_subtrees if TYPE_CHECKING: from xarray.core.datatree import DataTree @@ -125,110 +123,55 @@ def map_over_datasets(func: Callable) -> Callable: # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? @functools.wraps(func) - def _map_over_datasets(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: + def _map_over_datasets(*args) -> DataTree | tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" from xarray.core.datatree import DataTree - all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ - a for a in kwargs.values() if isinstance(a, DataTree) - ] - - if len(all_tree_inputs) > 0: - first_tree, *other_trees = all_tree_inputs - else: - raise TypeError("Must pass at least one tree object") - - for other_tree in other_trees: - # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic - check_isomorphic( - first_tree, other_tree, require_names_equal=False, check_from_root=False - ) - # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees # We don't know which arguments are DataTrees so we zip all arguments together as iterables # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return - out_data_objects = {} - args_as_tree_length_iterables = [ - a.subtree if isinstance(a, DataTree) else repeat(a) for a in args - ] - n_args = len(args_as_tree_length_iterables) - kwargs_as_tree_length_iterables = { - k: v.subtree if isinstance(v, DataTree) else repeat(v) - for k, v in kwargs.items() - } - for node_of_first_tree, *all_node_args in zip( - first_tree.subtree, - *args_as_tree_length_iterables, - *list(kwargs_as_tree_length_iterables.values()), - strict=False, - ): - node_args_as_datasetviews = [ - a.dataset if isinstance(a, DataTree) else a - for a in all_node_args[:n_args] - ] - node_kwargs_as_datasetviews = dict( - zip( - [k for k in kwargs_as_tree_length_iterables.keys()], - [ - v.dataset if isinstance(v, DataTree) else v - for v in all_node_args[n_args:] - ], - strict=True, - ) + out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {} + + tree_args = [arg for arg in args if isinstance(arg, DataTree)] + subtrees = zip_subtrees(*tree_args) + + for node_tree_args in subtrees: + + node_dataset_args = [arg.dataset for arg in node_tree_args] + for i, arg in enumerate(args): + if not isinstance(arg, DataTree): + node_dataset_args.insert(i, arg) + + path = ( + "/" + if node_tree_args[0] is tree_args[0] + else node_tree_args[0].relative_to(tree_args[0]) ) - func_with_error_context = _handle_errors_with_path_context( - node_of_first_tree.path - )(func) - - if node_of_first_tree.has_data: - # call func on the data in this particular set of corresponding nodes - results = func_with_error_context( - *node_args_as_datasetviews, **node_kwargs_as_datasetviews - ) - elif node_of_first_tree.has_attrs: - # propagate attrs - results = node_of_first_tree.dataset - else: - # nothing to propagate so use fastpath to create empty node in new tree - results = None + func_with_error_context = _handle_errors_with_path_context(path)(func) + results = func_with_error_context(*node_dataset_args) - # TODO implement mapping over multiple trees in-place using if conditions from here on? - out_data_objects[node_of_first_tree.path] = results + out_data_objects[path] = results - # Find out how many return values we received num_return_values = _check_all_return_values(out_data_objects) - # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees - original_root_path = first_tree.path - result_trees = [] - for i in range(num_return_values): - out_tree_contents = {} - for n in first_tree.subtree: - p = n.path - if p in out_data_objects.keys(): - if isinstance(out_data_objects[p], tuple): - output_node_data = out_data_objects[p][i] - else: - output_node_data = out_data_objects[p] - else: - output_node_data = None - - # Discard parentage so that new trees don't include parents of input nodes - relative_path = str(NodePath(p).relative_to(original_root_path)) - relative_path = "/" if relative_path == "." else relative_path - out_tree_contents[relative_path] = output_node_data - - new_tree = DataTree.from_dict( - out_tree_contents, - name=first_tree.name, - ) - result_trees.append(new_tree) + if num_return_values is None: + out_data = cast(Mapping[str, Dataset | None], out_data_objects) + return DataTree.from_dict(out_data, name=tree_args[0].name) - # If only one result then don't wrap it in a tuple - if len(result_trees) == 1: - return result_trees[0] - else: - return tuple(result_trees) + out_data_tuples = cast( + Mapping[str, tuple[Dataset | None, ...]], out_data_objects + ) + output_dicts: list[dict[str, Dataset | None]] = [ + {} for _ in range(num_return_values) + ] + for path, outputs in out_data_tuples.items(): + for output_dict, output in zip(output_dicts, outputs, strict=False): + output_dict[path] = output + + return tuple( + DataTree.from_dict(output_dict, name=tree_args[0].name) + for output_dict in output_dicts + ) return _map_over_datasets @@ -260,62 +203,54 @@ def add_note(err: BaseException, msg: str) -> None: err.add_note(msg) -def _check_single_set_return_values( - path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray] -): +def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None: """Check types returned from single evaluation of func, and return number of return values received from func.""" - if isinstance(obj, Dataset | DataArray): - return 1 - elif isinstance(obj, tuple): - for r in obj: - if not isinstance(r, Dataset | DataArray): - raise TypeError( - f"One of the results of calling func on datasets on the nodes at position {path_to_node} is " - f"of type {type(r)}, not Dataset or DataArray." - ) - return len(obj) - else: + if isinstance(obj, None | Dataset): + return None # no need to pack results + + if not isinstance(obj, tuple) or not all( + isinstance(r, Dataset | None) for r in obj + ): raise TypeError( - f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not " - f"Dataset or DataArray, nor a tuple of such types." + f"the result of calling func on the node at position is not a Dataset or None " + f"or a tuple of such types: {obj!r}" ) + return len(obj) -def _check_all_return_values(returned_objects): - """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" - if all(r is None for r in returned_objects.values()): - raise TypeError( - "Called supplied function on all nodes but found a return value of None for" - "all of them." - ) +def _check_all_return_values(returned_objects) -> int | None: + """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" result_data_objects = [ - (path_to_node, r) - for path_to_node, r in returned_objects.items() - if r is not None + (path_to_node, r) for path_to_node, r in returned_objects.items() ] - if len(result_data_objects) == 1: - # Only one node in the tree: no need to check consistency of results between nodes - path_to_node, result = result_data_objects[0] - num_return_values = _check_single_set_return_values(path_to_node, result) - else: - prev_path, _ = result_data_objects[0] - prev_num_return_values, num_return_values = None, None - for path_to_node, obj in result_data_objects[1:]: - num_return_values = _check_single_set_return_values(path_to_node, obj) - - if ( - num_return_values != prev_num_return_values - and prev_num_return_values is not None - ): + first_path, result = result_data_objects[0] + return_values = _check_single_set_return_values(first_path, result) + + for path_to_node, obj in result_data_objects[1:]: + cur_return_values = _check_single_set_return_values(path_to_node, obj) + + if return_values != cur_return_values: + if return_values is None: raise TypeError( - f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return " - f"values, whereas calling func on the nodes at position {prev_path} instead returns " - f"{prev_num_return_values} separate return values." + f"Calling func on the nodes at position {path_to_node} returns " + f"a tuple of {cur_return_values} datasets, whereas calling func on the " + f"nodes at position {first_path} instead returns a single dataset." + ) + elif cur_return_values is None: + raise TypeError( + f"Calling func on the nodes at position {path_to_node} returns " + f"a single dataset, whereas calling func on the nodes at position " + f"{first_path} instead returns a tuple of {return_values} datasets." + ) + else: + raise TypeError( + f"Calling func on the nodes at position {path_to_node} returns " + f"a tuple of {cur_return_values} datasets, whereas calling func on " + f"the nodes at position {first_path} instead returns a tuple of " + f"{return_values} datasets." ) - prev_path, prev_num_return_values = path_to_node, num_return_values - - return num_return_values + return return_values diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index a06e40b3688..c5bb08d4d86 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -791,6 +791,9 @@ def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]: ------ Tuples of matching subtrees. """ + if not trees: + raise TypeError("Must pass at least one tree object") + # https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode queue = collections.deque([trees]) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 766df76a259..4fa608f4b7c 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pytest @@ -108,7 +110,10 @@ def test_not_isomorphic(self, create_test_datatree): def times_ten(ds1, ds2): return ds1 * ds2 - with pytest.raises(TreeIsomorphismError): + with pytest.raises( + ValueError, + match=re.escape(r"children at '/set1/set2' do not match: [] vs ['extra']"), + ): times_ten(dt1, dt2) def test_no_trees_returned(self, create_test_datatree): @@ -119,10 +124,11 @@ def test_no_trees_returned(self, create_test_datatree): def bad_func(ds1, ds2): return None - with pytest.raises(TypeError, match="return value of None"): - bad_func(dt1, dt2) + expected = xr.DataTree.from_dict({k: None for k in dt1.to_dict()}) + actual = bad_func(dt1, dt2) + assert_equal(expected, actual) - def test_single_dt_arg(self, create_test_datatree): + def test_single_tree_arg(self, create_test_datatree): dt = create_test_datatree() @map_over_datasets @@ -133,18 +139,21 @@ def times_ten(ds): result_tree = times_ten(dt) assert_equal(result_tree, expected) - def test_single_dt_arg_plus_args_and_kwargs(self, create_test_datatree): + def test_single_tree_arg_plus_arg(self, create_test_datatree): dt = create_test_datatree() @map_over_datasets - def multiply_then_add(ds, times, add=0.0): - return (times * ds) + add + def multiply(ds, times): + return times * ds - expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) - result_tree = multiply_then_add(dt, 10.0, add=2.0) + expected = create_test_datatree(modify=lambda ds: (10.0 * ds)) + result_tree = multiply(dt, 10.0) assert_equal(result_tree, expected) - def test_multiple_dt_args(self, create_test_datatree): + result_tree = multiply(10.0, dt) + assert_equal(result_tree, expected) + + def test_multiple_tree_args(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() @@ -156,19 +165,7 @@ def add(ds1, ds2): result = add(dt1, dt2) assert_equal(result, expected) - def test_dt_as_kwarg(self, create_test_datatree): - dt1 = create_test_datatree() - dt2 = create_test_datatree() - - @map_over_datasets - def add(ds1, value=0.0): - return ds1 + value - - expected = create_test_datatree(modify=lambda ds: 2.0 * ds) - result = add(dt1, value=dt2) - assert_equal(result, expected) - - def test_return_multiple_dts(self, create_test_datatree): + def test_return_multiple_trees(self, create_test_datatree): dt = create_test_datatree() @map_over_datasets @@ -188,7 +185,13 @@ def test_return_wrong_type(self, simple_datatree): def bad_func(ds1): return "string" - with pytest.raises(TypeError, match="not Dataset or DataArray"): + with pytest.raises( + TypeError, + match=re.escape( + "the result of calling func on the node at position is not a " + "Dataset or None or a tuple of such types" + ), + ): bad_func(dt1) def test_return_tuple_of_wrong_types(self, simple_datatree): @@ -198,7 +201,13 @@ def test_return_tuple_of_wrong_types(self, simple_datatree): def bad_func(ds1): return xr.Dataset(), "string" - with pytest.raises(TypeError, match="not Dataset or DataArray"): + with pytest.raises( + TypeError, + match=re.escape( + "the result of calling func on the node at position is not a " + "Dataset or None or a tuple of such types" + ), + ): bad_func(dt1) @pytest.mark.xfail @@ -243,14 +252,14 @@ def test_trees_with_different_node_names(self): # TODO test this after I've got good tests for renaming nodes raise NotImplementedError - def test_dt_method(self, create_test_datatree): + def test_tree_method(self, create_test_datatree): dt = create_test_datatree() - def multiply_then_add(ds, times, add=0.0): - return times * ds + add + def multiply(ds, times): + return times * ds - expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) - result_tree = dt.map_over_datasets(multiply_then_add, 10.0, add=2.0) + expected = create_test_datatree(modify=lambda ds: 10.0 * ds) + result_tree = dt.map_over_datasets(multiply, 10.0) assert_equal(result_tree, expected) def test_discard_ancestry(self, create_test_datatree): @@ -266,18 +275,6 @@ def times_ten(ds): result_tree = times_ten(subtree) assert_equal(result_tree, expected, from_root=False) - def test_skip_empty_nodes_with_attrs(self, create_test_datatree): - # inspired by xarray-datatree GH262 - dt = create_test_datatree() - dt["set1/set2"].attrs["foo"] = "bar" - - def check_for_data(ds): - # fails if run on a node that has no data - assert len(ds.variables) != 0 - return ds - - dt.map_over_datasets(check_for_data) - def test_keep_attrs_on_empty_nodes(self, create_test_datatree): # GH278 dt = create_test_datatree() @@ -336,6 +333,8 @@ def test_construct_using_type(self): dt = xr.DataTree.from_dict({"a": a, "b": b}) def weighted_mean(ds): + if "area" not in ds.coords: + return None return ds.weighted(ds.area).mean(["x", "y"]) dt.map_over_datasets(weighted_mean) @@ -359,7 +358,7 @@ def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset: return ds with pytest.raises(AttributeError): - simpsons.map_over_datasets(fast_forward, years=10) + simpsons.map_over_datasets(fast_forward, 10) @pytest.mark.xfail