I think there are some good opportunities to run `map_over_subtree` in parallel using `dask.delayed`.

Consider this example data:
```python
import numpy as np
import xarray as xr
from datatree import DataTree

number_of_files = 25
number_of_groups = 20
number_of_variables = 2000

datasets = {}
for f in range(number_of_files):
    for g in range(number_of_groups):
        # Create random data:
        time = np.linspace(0, 50 + f, 100 + g)
        y = f * time + g

        # Create dataset:
        ds = xr.Dataset(
            data_vars={
                f"temperature_{g}{i}": ("time", y)
                for i in range(number_of_variables // number_of_groups)
            },
            coords={"time": ("time", time)},
        )  # .chunk()

        # Prepare for Datatree:
        name = f"file_{f}/group_{g}"
        datasets[name] = ds

dt = DataTree.from_dict(datasets)

# %% Interpolate to same time coordinate
new_time = np.linspace(0, 150, 50)
dt_interp = dt.interp(time=new_time)
# Original 10s, with dask.delayed 6s
# If datasets were chunked: Original 34s, with dask.delayed 10s
```
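The core pattern here is to turn each eager per-node call into a `dask.delayed` call and then execute all of them in one `dask.compute` batch. A minimal standalone sketch of just that pattern (independent of datatree; the `interp_node` helper and node names are made up for illustration):

```python
import dask
import numpy as np
import xarray as xr

new_time = np.linspace(0, 99, 10)


def interp_node(ds):
    # The per-node work we want to parallelize.
    return ds.interp(time=new_time)


# A handful of independent datasets, standing in for tree nodes.
datasets = {
    f"node_{i}": xr.Dataset(
        {"y": ("time", np.arange(100.0) * i)},
        coords={"time": np.arange(100.0)},
    )
    for i in range(4)
}

# Delay each per-node call instead of executing it eagerly, then run them
# all in a single batch so dask can schedule them in parallel.
delayed = {path: dask.delayed(interp_node)(ds) for path, ds in datasets.items()}
(results,) = dask.compute(delayed)
```

That batching is exactly what the modification below does inside the tree walk.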
Here's my modded `map_over_subtree`:
```python
import functools
from itertools import repeat
from typing import Callable, Tuple

# check_isomorphic, _check_all_return_values and NodePath are datatree
# internals, available where this function lives (datatree/mapping.py).


def map_over_subtree(func: Callable) -> Callable:
    """
    Decorator which turns a function which acts on (and returns) Datasets into
    one which acts on and returns DataTrees.

    Applies a function to every dataset in one or more subtrees, returning new
    trees which store the results.

    The function will be applied to any non-empty dataset stored in any of the
    nodes in the trees. The returned trees will have the same structure as the
    supplied trees.

    `func` needs to return Datasets, DataArrays, or None in order to be able to
    rebuild the subtrees after mapping, as each result will be assigned to its
    respective node of a new tree via `DataTree.__setitem__`. Any returned value
    that is one of these types will be stacked into a separate tree before
    returning all of them.

    The trees passed to the resulting function must all be isomorphic to one
    another. Their nodes need not be named similarly, but all the output trees
    will have nodes named in the same way as the first tree passed.

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:
        `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.

        (i.e. func must accept at least one Dataset and return at least one
        Dataset.) Function will not be applied to any nodes without datasets.
    *args : tuple, optional
        Positional arguments passed on to `func`. For any DataTree arguments,
        data-containing nodes will be converted to Datasets via `.ds`.
    **kwargs : Any
        Keyword arguments passed on to `func`. For any DataTree arguments,
        data-containing nodes will be converted to Datasets via `.ds`.

    Returns
    -------
    mapped : callable
        Wrapped function which returns one or more tree(s) created from results
        of applying ``func`` to the dataset at each node.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    DataTree.subtree
    """

    # TODO examples in the docstring

    # TODO inspect function to work out immediately if the wrong number of
    # arguments were passed for it?

    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        """Internal function which maps func over every node in tree, returning a tree of the results."""
        from .datatree import DataTree

        parallel = True  # hard-coded for now, see the question below
        if parallel:
            import dask

            func_ = dask.delayed(func)
        else:
            func_ = func

        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
            a for a in kwargs.values() if isinstance(a, DataTree)
        ]

        if len(all_tree_inputs) > 0:
            first_tree, *other_trees = all_tree_inputs
        else:
            raise TypeError("Must pass at least one tree object")

        for other_tree in other_trees:
            # isomorphism is transitive so this is enough to guarantee all
            # trees are mutually isomorphic
            check_isomorphic(
                first_tree,
                other_tree,
                require_names_equal=False,
                check_from_root=False,
            )

        # Walk all trees simultaneously, applying func to all nodes that lie in
        # the same position in different trees.
        # We don't know which arguments are DataTrees so we zip all arguments
        # together as iterables.
        # Store tuples of results in a dict because we don't yet know how many
        # trees we need to rebuild to return.
        out_data_objects = {}
        args_as_tree_length_iterables = [
            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
        ]
        n_args = len(args_as_tree_length_iterables)
        kwargs_as_tree_length_iterables = {
            k: v.subtree if isinstance(v, DataTree) else repeat(v)
            for k, v in kwargs.items()
        }
        for node_of_first_tree, *all_node_args in zip(
            first_tree.subtree,
            *args_as_tree_length_iterables,
            *list(kwargs_as_tree_length_iterables.values()),
        ):
            node_args_as_datasets = [
                a.to_dataset() if isinstance(a, DataTree) else a
                for a in all_node_args[:n_args]
            ]
            node_kwargs_as_datasets = dict(
                zip(
                    [k for k in kwargs_as_tree_length_iterables.keys()],
                    [
                        v.to_dataset() if isinstance(v, DataTree) else v
                        for v in all_node_args[n_args:]
                    ],
                )
            )

            # Now we can call func on the data in this particular set of
            # corresponding nodes
            results = (
                func_(*node_args_as_datasets, **node_kwargs_as_datasets)
                if not node_of_first_tree.is_empty
                else None
            )

            # TODO implement mapping over multiple trees in-place using if
            # conditions from here on?
            out_data_objects[node_of_first_tree.path] = results

        if parallel:
            # Trigger all the delayed per-node calls in a single batch.
            keys, values = dask.compute(
                [k for k in out_data_objects.keys()],
                [v for v in out_data_objects.values()],
            )
            out_data_objects = {k: v for k, v in zip(keys, values)}

        # Find out how many return values we received
        num_return_values = _check_all_return_values(out_data_objects)

        # Reconstruct 1+ subtrees from the dict of results, by filling in all
        # nodes of all result trees
        original_root_path = first_tree.path
        result_trees = []
        for i in range(num_return_values):
            out_tree_contents = {}
            for n in first_tree.subtree:
                p = n.path
                if p in out_data_objects.keys():
                    if isinstance(out_data_objects[p], tuple):
                        output_node_data = out_data_objects[p][i]
                    else:
                        output_node_data = out_data_objects[p]
                else:
                    output_node_data = None

                # Discard parentage so that new trees don't include parents of
                # input nodes
                relative_path = str(NodePath(p).relative_to(original_root_path))
                relative_path = "/" if relative_path == "." else relative_path
                out_tree_contents[relative_path] = output_node_data

            new_tree = DataTree.from_dict(
                out_tree_contents,
                name=first_tree.name,
            )
            result_trees.append(new_tree)

        # If only one result then don't wrap it in a tuple
        if len(result_trees) == 1:
            return result_trees[0]
        else:
            return tuple(result_trees)

    return _map_over_subtree
```
I'm a little unsure how to get the `parallel` argument down to `map_over_subtree`, though.
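One possibility might be to reserve a keyword argument that the wrapper pops off before forwarding the rest to `func`. A rough sketch, with the tree walk elided and the `map_parallel` name made up:

```python
import functools
from typing import Callable

import dask


def map_over_subtree(func: Callable) -> Callable:
    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs):
        # Hypothetical reserved keyword: pop it so it never reaches func itself.
        parallel = kwargs.pop("map_parallel", False)
        func_ = dask.delayed(func) if parallel else func

        # The subtree walk from the version above would go here, calling func_
        # on each node; this sketch applies func_ once to stay minimal.
        result = func_(*args, **kwargs)

        if parallel:
            (result,) = dask.compute(result)
        return result

    return _map_over_subtree
```

A call like `dt.interp(time=new_time, map_parallel=True)` could then opt in per call, at the cost of `map_parallel` becoming a reserved name that can never be a real keyword argument of `func`.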
Copied from xarray-contrib/datatree#252