From a96e7598b279499ac9e58360184b6603351e9ffc Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 2 Mar 2020 17:41:16 +0530 Subject: [PATCH 01/23] MVP for dask collections in args --- xarray/core/parallel.py | 185 ++++++++++++++++++++++++-------------- xarray/tests/test_dask.py | 3 - 2 files changed, 117 insertions(+), 71 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index d91dfb4a275..a8301c28458 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -55,6 +55,61 @@ def check_result_variables( ) +def subset_dataset_to_chunk(graph, gname, dataset, input_chunks, chunk_tuple): + + # mapping from dimension name to chunk index + input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) + + # mapping from chunk index to slice bounds + chunk_index_bounds = { + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() + } + + # this will become [[name1, variable1], + # [name2, variable2], + # ...] + # which is passed to dict and then to Dataset + data_vars = [] + coords = [] + + for name, variable in dataset.variables.items(): + # make a task that creates tuple of (dims, chunk) + if dask.is_dask_collection(variable.data): + # recursively index into dask_keys nested list to get chunk + chunk = variable.__dask_keys__() + for dim in variable.dims: + chunk = chunk[chunk_index[dim]] + + chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v + graph[chunk_variable_task] = ( + tuple, + [variable.dims, chunk, variable.attrs], + ) + else: + # non-dask array with possibly chunked dimensions + # index into variable appropriately + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } + subset = variable.isel(subsetter) + chunk_variable_task = ( + "{}-{}".format(gname, dask.base.tokenize(subset)), + ) + v + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset, subset.attrs], + ) + + # this task creates dict mapping variable name to above tuple + if name in dataset._coord_names: + coords.append([name, chunk_variable_task]) + else: + data_vars.append([name, chunk_variable_task]) + + return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + + def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): raise TypeError("Expected Dataset, got %s" % type(obj)) @@ -67,6 +122,17 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray: return next(iter(obj.data_vars.values())) +def dataarray_to_dataset(obj: DataArray) -> Dataset: + # only using _to_temp_dataset would break + # func = lambda x: x.to_dataset() + # since that relies on preserving name. + if obj.name is None: + dataset = obj._to_temp_dataset() + else: + dataset = obj.to_dataset() + return dataset + + def make_meta(obj): """If obj is a DataArray or Dataset, return a new object of the same type and with the same variables and dtypes, but where all variables have size 0 and numpy @@ -161,8 +227,8 @@ def map_blocks( obj: DataArray, Dataset Passed to the function as its first argument, one dask chunk at a time. args: Sequence - Passed verbatim to func after unpacking, after the sliced obj. xarray objects, - if any, will not be split by chunks. Passing dask collections is not allowed. + Passed verbatim to func after unpacking, after the sliced obj. + Any xarray objects will also be split by chunks and then passed on. kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. 
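(A sketch of the graph structure the new ``subset_dataset_to_chunk`` helper
emits, using dask's tuple task convention; the key names here are
illustrative, not the actual tokenized names. Each dask-backed variable gets
one task per block that re-wraps its chunk as a ``(dims, data, attrs)``
tuple, and a nested ``Dataset`` task reassembles those tuples into the
per-block object handed to ``func``:

    # hypothetical graph entries for the block at chunk indices (0, 1)
    graph[("gname-a-chunk", 0, 1)] = (tuple, [("x", "y"), chunk_key, attrs])
    blocked_arg = (
        Dataset,
        (dict, [["a", ("gname-a-chunk", 0, 1)]]),   # data_vars
        (dict, [["x", ("gname-x-subset", 0, 1)]]),  # coords
        attrs,
    )

The ``blocked_arg`` tuple is embedded directly in the ``_wrapper`` task for
that block.)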
@@ -241,14 +307,16 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ - def _wrapper(func, obj, to_array, args, kwargs, expected): - check_shapes = dict(obj.dims) + def _wrapper(func, args, kwargs, arg_is_array, expected): + check_shapes = dict(args[0].dims) check_shapes.update(expected["shapes"]) - if to_array: - obj = dataset_to_dataarray(obj) + converted_args = [ + dataset_to_dataarray(arg) if is_array else arg + for is_array, arg in zip(arg_is_array, args) + ] - result = func(obj, *args, **kwargs) + result = func(*converted_args, **kwargs) # check all dims are present missing_dimensions = set(expected["shapes"]) - set(result.sizes) @@ -289,10 +357,10 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): elif not isinstance(kwargs, Mapping): raise TypeError("kwargs must be a mapping (for example, a dict)") - for value in list(args) + list(kwargs.values()): + for value in kwargs.values(): if dask.is_dask_collection(value): raise TypeError( - "Cannot pass dask collections in args or kwargs yet. Please compute or " + "Cannot pass dask collections in kwargs yet. Please compute or " "load values before passing to map_blocks." ) @@ -300,27 +368,35 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): return func(obj, *args, **kwargs) if isinstance(obj, DataArray): - # only using _to_temp_dataset would break - # func = lambda x: x.to_dataset() - # since that relies on preserving name. - if obj.name is None: - dataset = obj._to_temp_dataset() - else: - dataset = obj.to_dataset() + dataset = dataarray_to_dataset(obj) input_is_array = True else: dataset = obj input_is_array = False - input_chunks = dataset.chunks - dataset_indexes = set(dataset.indexes) + # TODO: align args and dataset here? + # TODO: unify_chunks for args and dataset + input_chunks = dict(dataset.chunks) + input_indexes = dict(dataset.indexes) + converted_args = [] + arg_is_array = [] + for arg in args: + arg_is_array.append(isinstance(arg, DataArray)) + if isinstance(arg, (Dataset, DataArray)): + if isinstance(arg, DataArray): + converted_args.append(dataarray_to_dataset(arg)) + input_chunks.update(converted_args[-1].chunks) + input_indexes.update(converted_args[-1].indexes) + else: + converted_args.append(arg) + if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, obj, *args, **kwargs) template_indexes = set(template.indexes) - preserved_indexes = template_indexes & dataset_indexes - new_indexes = template_indexes - dataset_indexes - indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} + preserved_indexes = template_indexes & set(input_indexes) + new_indexes = template_indexes - set(input_indexes) + indexes = {dim: input_indexes[dim] for dim in preserved_indexes} indexes.update({k: template.indexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks @@ -329,7 +405,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): else: # template xarray object has been provided with proper sizes and chunk shapes template_indexes = set(template.indexes) - indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes} + indexes = input_indexes indexes.update({k: template.indexes[k] for k in template_indexes}) if isinstance(template, DataArray): output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore @@ -377,50 +453,20 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): } # iterate over all possible chunk combinations - for v in 
itertools.product(*ichunk.values()): - chunk_index = dict(zip(dataset.dims, v)) - - # this will become [[name1, variable1], - # [name2, variable2], - # ...] - # which is passed to dict and then to Dataset - data_vars = [] - coords = [] - - for name, variable in dataset.variables.items(): - # make a task that creates tuple of (dims, chunk) - if dask.is_dask_collection(variable.data): - # recursively index into dask_keys nested list to get chunk - chunk = variable.__dask_keys__() - for dim in variable.dims: - chunk = chunk[chunk_index[dim]] - - chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v - graph[chunk_variable_task] = ( - tuple, - [variable.dims, chunk, variable.attrs], - ) - else: - # non-dask array with possibly chunked dimensions - # index into variable appropriately - subsetter = { - dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) - for dim in variable.dims - } - subset = variable.isel(subsetter) - chunk_variable_task = ( - "{}-{}".format(gname, dask.base.tokenize(subset)), - ) + v - graph[chunk_variable_task] = ( - tuple, - [subset.dims, subset, subset.attrs], + for chunk_tuple in itertools.product(*ichunk.values()): + # mapping from dimension name to chunk index + input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) + + chunked_args = [] + for arg in (dataset,) + tuple(converted_args): + if isinstance(arg, (DataArray, Dataset)): + chunked_args.append( + subset_dataset_to_chunk( + graph, gname, arg, input_chunks, chunk_tuple + ) ) - - # this task creates dict mapping variable name to above tuple - if name in dataset._coord_names: - coords.append([name, chunk_variable_task]) else: - data_vars.append([name, chunk_variable_task]) + chunked_args.append(arg) # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper expected = {} @@ -440,10 +486,9 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): graph[from_wrapper] = ( _wrapper, func, - (Dataset, (dict, data_vars), (dict, coords), dataset.attrs), - input_is_array, - args, + chunked_args, kwargs, + [input_is_array] + arg_is_array, expected, ) @@ -472,7 +517,11 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): # layer. new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) - hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + hlg = HighLevelGraph.from_collections( + gname, + graph, + dependencies=[dataset] + [arg for arg in args if dask.is_dask_collection(arg)], + ) for gname_l, layer in new_layers.items(): # This adds in the getitems for each variable in the dataset. diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 75beb3757ca..0372c223172 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1066,9 +1066,6 @@ def really_bad_func(darray): with raises_regex(ValueError, "inconsistent chunks"): xr.map_blocks(bad_func, ds_copy) - with raises_regex(TypeError, "Cannot pass dask collections"): - xr.map_blocks(bad_func, map_da, args=[map_da.chunk()]) - with raises_regex(TypeError, "Cannot pass dask collections"): xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) From 5c17e3f63fea26664b43436f4c4c98622b546eea Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 3 Mar 2020 23:40:16 +0530 Subject: [PATCH 02/23] Add tests. 
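These tests exercise the behavior added in the previous commit: dask-backed
xarray objects passed via ``args`` are split into blocks matching ``obj``
instead of being rejected. Roughly, mirroring the new test below:

    import operator

    import numpy as np
    import xarray as xr

    da1 = xr.DataArray(
        np.ones((10, 20)),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(20)},
    ).chunk({"x": 5, "y": 4})
    da2 = da1 + 1

    # da2 is subset block-wise and passed as the second argument to func,
    # without triggering compute
    mapped = xr.map_blocks(operator.add, da1, args=[da2])
    xr.testing.assert_equal(da1 + da2, mapped)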
--- xarray/core/parallel.py | 2 +- xarray/tests/test_dask.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a8301c28458..b46acca0ba4 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -410,7 +410,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): if isinstance(template, DataArray): output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore else: - output_chunks = template.chunks # type: ignore + output_chunks = dict(template.chunks) for dim in output_chunks: if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 0372c223172..abba34275ac 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1092,6 +1092,36 @@ def test_map_blocks_convert_args_to_list(obj): assert_identical(actual, expected) +def test_map_blocks_dask_args(): + da1 = xr.DataArray( + np.ones((10, 20)), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(20)}, + ).chunk({"x": 5, "y": 4}) + + # check that block shapes are the same + def sumda(da1, da2): + assert da1.shape == da2.shape + return da1 + da2 + + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(sumda, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # one dimension in common + da2 = (da1 + 1).isel(x=1, drop=True) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # test that everything works when dimension names are different + da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"}) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): def add_attrs(obj): From 74cb1189d5f36f1203de82d027f7282f1432b378 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 9 Mar 2020 15:05:05 +0530 Subject: [PATCH 03/23] Use list comprehension --- xarray/core/parallel.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index b46acca0ba4..8be88e5ca65 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -55,7 +55,7 @@ def check_result_variables( ) -def subset_dataset_to_chunk(graph, gname, dataset, input_chunks, chunk_tuple): +def subset_dataset_to_block(graph, gname, dataset, input_chunks, chunk_tuple): # mapping from dimension name to chunk index input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) @@ -457,16 +457,12 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): # mapping from dimension name to chunk index input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) - chunked_args = [] - for arg in (dataset,) + tuple(converted_args): - if isinstance(arg, (DataArray, Dataset)): - chunked_args.append( - subset_dataset_to_chunk( - graph, gname, arg, input_chunks, chunk_tuple - ) - ) - else: - chunked_args.append(arg) + blocked_args = [ + subset_dataset_to_block(graph, gname, arg, input_chunks, chunk_tuple) + if isinstance(arg, (DataArray, Dataset)) + else arg + for arg in (dataset,) + tuple(converted_args) + ] # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper expected = {} @@ -486,7 +482,7 @@ def _wrapper(func, args, kwargs, 
arg_is_array, expected): graph[from_wrapper] = ( _wrapper, func, - chunked_args, + blocked_args, kwargs, [input_is_array] + arg_is_array, expected, From 639ad6902f7ab2e2fc3dc3821f8be9cd7ecae945 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 21 Mar 2020 18:42:09 -0600 Subject: [PATCH 04/23] map_blocks: preserve attrs of dimension coordinates in input Switch to use IndexVariables instead of Indexes so that attrs are preserved. --- xarray/core/parallel.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 8be88e5ca65..42eb9c8558c 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -31,6 +31,10 @@ T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +def get_index_vars(obj): + return {dim: obj[dim] for dim in obj.indexes} + + def check_result_variables( result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str ): @@ -377,7 +381,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): # TODO: align args and dataset here? # TODO: unify_chunks for args and dataset input_chunks = dict(dataset.chunks) - input_indexes = dict(dataset.indexes) + input_indexes = get_index_vars(dataset) converted_args = [] arg_is_array = [] for arg in args: @@ -397,16 +401,15 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.indexes[k] for k in new_indexes}) + indexes.update({k: template[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - template_indexes = set(template.indexes) indexes = input_indexes - indexes.update({k: template.indexes[k] for k in template_indexes}) + indexes.update(get_index_vars(template)) if isinstance(template, DataArray): output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore else: From 748828e3fa3fcc960d482cf32ad7d2a4c2fbeab4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Mar 2020 09:12:25 -0600 Subject: [PATCH 05/23] Check that chunk sizes are compatible. --- xarray/core/parallel.py | 11 ++++++++++- xarray/tests/test_dask.py | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 42eb9c8558c..17aa8562413 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -35,6 +35,15 @@ def get_index_vars(obj): return {dim: obj[dim] for dim in obj.indexes} +def assert_chunks_compatible(a: Dataset, b: Dataset): + a = a.unify_chunks() + b = b.unify_chunks() + + for dim in set(a.chunks).intersection(set(b.chunks)): + if a.chunks[dim] != b.chunks[dim]: + raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.") + + def check_result_variables( result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str ): @@ -379,7 +388,6 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): input_is_array = False # TODO: align args and dataset here? 
- # TODO: unify_chunks for args and dataset input_chunks = dict(dataset.chunks) input_indexes = get_index_vars(dataset) converted_args = [] @@ -389,6 +397,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): if isinstance(arg, (Dataset, DataArray)): if isinstance(arg, DataArray): converted_args.append(dataarray_to_dataset(arg)) + assert_chunks_compatible(dataset, converted_args[-1]) input_chunks.update(converted_args[-1].chunks) input_indexes.update(converted_args[-1].indexes) else: diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index abba34275ac..d5819624118 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1121,6 +1121,9 @@ def sumda(da1, da2): mapped = xr.map_blocks(operator.add, da1, args=[da2]) xr.testing.assert_equal(da1 + da2, mapped) + with raises_regex(ValueError, "Chunk sizes along dimension 'x'"): + xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): From f73491a7f7d3f54fa91bc8277367851402a2681d Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Mar 2020 09:26:13 -0600 Subject: [PATCH 06/23] Align all xarray objects --- xarray/core/parallel.py | 71 +++++++++++++++++++-------------------- xarray/tests/test_dask.py | 4 +++ 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 17aa8562413..f174cca0afc 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -25,6 +25,7 @@ import numpy as np +from .alignment import align from .dataarray import DataArray from .dataset import Dataset @@ -35,6 +36,13 @@ def get_index_vars(obj): return {dim: obj[dim] for dim in obj.indexes} +def to_object_array(iterable): + npargs = np.empty((len(iterable),), dtype=np.object) + for idx, item in enumerate(iterable): + npargs[idx] = item + return npargs + + def assert_chunks_compatible(a: Dataset, b: Dataset): a = a.unify_chunks() b = b.unify_chunks() @@ -380,32 +388,30 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): if not dask.is_dask_collection(obj): return func(obj, *args, **kwargs) - if isinstance(obj, DataArray): - dataset = dataarray_to_dataset(obj) - input_is_array = True - else: - dataset = obj - input_is_array = False - - # TODO: align args and dataset here? - input_chunks = dict(dataset.chunks) - input_indexes = get_index_vars(dataset) - converted_args = [] - arg_is_array = [] - for arg in args: - arg_is_array.append(isinstance(arg, DataArray)) - if isinstance(arg, (Dataset, DataArray)): - if isinstance(arg, DataArray): - converted_args.append(dataarray_to_dataset(arg)) - assert_chunks_compatible(dataset, converted_args[-1]) - input_chunks.update(converted_args[-1].chunks) - input_indexes.update(converted_args[-1].indexes) - else: - converted_args.append(arg) + npargs = to_object_array([obj] + list(args)) + is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs] + is_array = [isinstance(arg, DataArray) for arg in npargs] + + # align all xarray objects + # TODO: should we allow join as a kwarg or force everything to be aligned to the first object? 
+ aligned = align(*npargs[is_xarray], join="left") + # assigning to object arrays works better when RHS is object array + # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array + npargs[is_xarray] = to_object_array(aligned) + npargs[is_array] = to_object_array( + [dataarray_to_dataset(da) for da in npargs[is_array]] + ) + + input_chunks = dict(npargs[0].chunks) + input_indexes = get_index_vars(npargs[0]) + for arg in npargs[1:][is_xarray[1:]]: + assert_chunks_compatible(npargs[0], arg) + input_chunks.update(arg.chunks) + input_indexes.update(arg.indexes) if template is None: # infer template by providing zero-shaped arrays - template = infer_template(func, obj, *args, **kwargs) + template = infer_template(func, aligned[0], *args, **kwargs) template_indexes = set(template.indexes) preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) @@ -451,7 +457,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): graph: Dict[Any, Any] = {} new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( - dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) + dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs) ) # map dims to list of chunk indexes @@ -471,9 +477,9 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): blocked_args = [ subset_dataset_to_block(graph, gname, arg, input_chunks, chunk_tuple) - if isinstance(arg, (DataArray, Dataset)) + if isxr else arg - for arg in (dataset,) + tuple(converted_args) + for isxr, arg in zip(is_xarray, npargs) ] # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper @@ -490,15 +496,8 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): for dim in indexes } - from_wrapper = (gname,) + v - graph[from_wrapper] = ( - _wrapper, - func, - blocked_args, - kwargs, - [input_is_array] + arg_is_array, - expected, - ) + from_wrapper = (gname,) + chunk_tuple + graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) # mapping from variable name to dask graph key var_key_map: Dict[Hashable, str] = {} @@ -528,7 +527,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): hlg = HighLevelGraph.from_collections( gname, graph, - dependencies=[dataset] + [arg for arg in args if dask.is_dask_collection(arg)], + dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)], ) for gname_l, layer in new_layers.items(): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index d5819624118..63234a5319e 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1124,6 +1124,10 @@ def sumda(da1, da2): with raises_regex(ValueError, "Chunk sizes along dimension 'x'"): xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) + xr.testing.assert_equal(da1 + da1, mapped) + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): From 402964930b5116f3a511cb759d2196500c57a521 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Mar 2020 13:18:41 -0600 Subject: [PATCH 07/23] Add some type hints. 
--- xarray/core/parallel.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f174cca0afc..c837455135f 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -16,6 +16,8 @@ DefaultDict, Dict, Hashable, + Iterable, + List, Mapping, Sequence, Tuple, @@ -32,7 +34,7 @@ T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) -def get_index_vars(obj): +def get_index_vars(obj: Union[DataArray, Dataset]) -> dict: return {dim: obj[dim] for dim in obj.indexes} @@ -76,7 +78,9 @@ def check_result_variables( ) -def subset_dataset_to_block(graph, gname, dataset, input_chunks, chunk_tuple): +def subset_dataset_to_block( + graph: dict, gname: str, dataset: Dataset, input_chunks: dict, chunk_tuple: tuple +): # mapping from dimension name to chunk index input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) @@ -328,7 +332,13 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ - def _wrapper(func, args, kwargs, arg_is_array, expected): + def _wrapper( + func: Callable, + args: List, + kwargs: dict, + arg_is_array: Iterable[bool], + expected: dict, + ): check_shapes = dict(args[0].dims) check_shapes.update(expected["shapes"]) From d2f291695c010ea3a2965880384291921b416f96 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 6 May 2020 11:37:45 -0600 Subject: [PATCH 08/23] fix rebase --- xarray/core/parallel.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index c837455135f..fc7a10b8adb 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -79,16 +79,16 @@ def check_result_variables( def subset_dataset_to_block( - graph: dict, gname: str, dataset: Dataset, input_chunks: dict, chunk_tuple: tuple + graph: dict, + gname: str, + dataset: Dataset, + input_chunks: dict, + input_chunk_bounds, + chunk_tuple: tuple, ): # mapping from dimension name to chunk index - input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) - - # mapping from chunk index to slice bounds - chunk_index_bounds = { - dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() - } + chunk_index = dict(zip(dataset.dims, chunk_tuple)) # this will become [[name1, variable1], # [name2, variable2], @@ -105,11 +105,8 @@ def subset_dataset_to_block( for dim in variable.dims: chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v - graph[chunk_variable_task] = ( - tuple, - [variable.dims, chunk, variable.attrs], - ) + chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple + graph[chunk_variable_task] = (tuple, [variable.dims, chunk, variable.attrs]) else: # non-dask array with possibly chunked dimensions # index into variable appropriately @@ -120,11 +117,8 @@ def subset_dataset_to_block( subset = variable.isel(subsetter) chunk_variable_task = ( "{}-{}".format(gname, dask.base.tokenize(subset)), - ) + v - graph[chunk_variable_task] = ( - tuple, - [subset.dims, subset, subset.attrs], - ) + ) + chunk_tuple + graph[chunk_variable_task] = (tuple, [subset.dims, subset, subset.attrs]) # this task creates dict mapping variable name to above tuple if name in dataset._coord_names: @@ -357,7 +351,7 @@ def _wrapper( ) # check that index lengths and values are as expected - for name, index in result.indexes.items(): + for name, index in get_index_vars(result).items(): if name in check_shapes: if len(index) != check_shapes[name]: 
raise ValueError( @@ -483,10 +477,12 @@ def _wrapper( # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index - input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) + chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) blocked_args = [ - subset_dataset_to_block(graph, gname, arg, input_chunks, chunk_tuple) + subset_dataset_to_block( + graph, gname, arg, input_chunks, input_chunk_bounds, chunk_tuple + ) if isxr else arg for isxr, arg in zip(is_xarray, npargs) From 32a37c7701f755a26c702320d9097523102d790c Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 6 May 2020 12:02:18 -0600 Subject: [PATCH 09/23] move _wrapper out --- xarray/core/parallel.py | 94 +++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fc7a10b8adb..4a5c7990f9b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -219,6 +219,55 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping return slice(None) +def _wrapper( + func: Callable, + args: List, + kwargs: dict, + arg_is_array: Iterable[bool], + expected: dict, +): + """ + Wrapper function that receives datasets in args; converts to dataarrays when necessary; + passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. + """ + + check_shapes = dict(args[0].dims) + check_shapes.update(expected["shapes"]) + + converted_args = [ + dataset_to_dataarray(arg) if is_array else arg + for is_array, arg in zip(arg_is_array, args) + ] + + result = func(*converted_args, **kwargs) + + # check all dims are present + missing_dimensions = set(expected["shapes"]) - set(result.sizes) + if missing_dimensions: + raise ValueError(f"Dimensions {missing_dimensions} missing on returned object.") + + # check that index lengths and values are as expected + for name, index in get_index_vars(result).items(): + if name in check_shapes: + if len(index) != check_shapes[name]: + raise ValueError( + f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." + ) + if name in expected["indexes"]: + expected_index = expected["indexes"][name] + if not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." + ) + + # check that all expected variables were returned + check_result_variables(result, expected, "coords") + if isinstance(result, Dataset): + check_result_variables(result, expected, "data_vars") + + return make_dict(result) + + def map_blocks( func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], @@ -326,51 +375,6 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ - def _wrapper( - func: Callable, - args: List, - kwargs: dict, - arg_is_array: Iterable[bool], - expected: dict, - ): - check_shapes = dict(args[0].dims) - check_shapes.update(expected["shapes"]) - - converted_args = [ - dataset_to_dataarray(arg) if is_array else arg - for is_array, arg in zip(arg_is_array, args) - ] - - result = func(*converted_args, **kwargs) - - # check all dims are present - missing_dimensions = set(expected["shapes"]) - set(result.sizes) - if missing_dimensions: - raise ValueError( - f"Dimensions {missing_dimensions} missing on returned object." 
- ) - - # check that index lengths and values are as expected - for name, index in get_index_vars(result).items(): - if name in check_shapes: - if len(index) != check_shapes[name]: - raise ValueError( - f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." - ) - if name in expected["indexes"]: - expected_index = expected["indexes"][name] - if not index.equals(expected_index): - raise ValueError( - f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." - ) - - # check that all expected variables were returned - check_result_variables(result, expected, "coords") - if isinstance(result, Dataset): - check_result_variables(result, expected, "data_vars") - - return make_dict(result) - if template is not None and not isinstance(template, (DataArray, Dataset)): raise TypeError( f"template must be a DataArray or Dataset. Received {type(template).__name__} instead." From a0e699fe9a9752288c2b365d27b9fddc677ddf79 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 6 May 2020 12:31:15 -0600 Subject: [PATCH 10/23] Fixes --- xarray/core/parallel.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 4a5c7990f9b..c2473b9e3b2 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -79,24 +79,24 @@ def check_result_variables( def subset_dataset_to_block( - graph: dict, - gname: str, - dataset: Dataset, - input_chunks: dict, - input_chunk_bounds, - chunk_tuple: tuple, + graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index ): + """ + Creates a task that creates a subsets xarray dataset to a block determined by chunk_index; + whose extents are determined by input_chunk_bounds. + There are subtasks that create subsets of constituent variables. - # mapping from dimension name to chunk index - chunk_index = dict(zip(dataset.dims, chunk_tuple)) + TODO: This is modifying graph in-place! + """ # this will become [[name1, variable1], - # [name2, variable2], - # ...] + # [name2, variable2], + # ...] 
# which is passed to dict and then to Dataset data_vars = [] coords = [] + chunk_tuple = tuple(chunk_index.values()) for name, variable in dataset.variables.items(): # make a task that creates tuple of (dims, chunk) if dask.is_dask_collection(variable.data): @@ -410,12 +410,13 @@ def map_blocks( [dataarray_to_dataset(da) for da in npargs[is_array]] ) + # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) input_indexes = get_index_vars(npargs[0]) for arg in npargs[1:][is_xarray[1:]]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg.indexes) + input_indexes.update(get_index_vars(arg)) if template is None: # infer template by providing zero-shaped arrays @@ -481,12 +482,10 @@ def map_blocks( # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index - chunk_index = dict(zip(input_chunks.keys(), chunk_tuple)) + chunk_index = dict(zip(ichunk.keys(), chunk_tuple)) blocked_args = [ - subset_dataset_to_block( - graph, gname, arg, input_chunks, input_chunk_bounds, chunk_tuple - ) + subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index) if isxr else arg for isxr, arg in zip(is_xarray, npargs) From 4d40a2548f358e1d1ec010b38dc42ff0bebfc72d Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 6 May 2020 13:32:16 -0600 Subject: [PATCH 11/23] avoid index dataarrays for simplicity. need a solution to preserve index attrs --- xarray/core/parallel.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index c2473b9e3b2..c2b09c119ad 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -34,10 +34,6 @@ T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) -def get_index_vars(obj: Union[DataArray, Dataset]) -> dict: - return {dim: obj[dim] for dim in obj.indexes} - - def to_object_array(iterable): npargs = np.empty((len(iterable),), dtype=np.object) for idx, item in enumerate(iterable): @@ -247,7 +243,7 @@ def _wrapper( raise ValueError(f"Dimensions {missing_dimensions} missing on returned object.") # check that index lengths and values are as expected - for name, index in get_index_vars(result).items(): + for name, index in result.indexes.items(): if name in check_shapes: if len(index) != check_shapes[name]: raise ValueError( @@ -412,11 +408,11 @@ def map_blocks( # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = get_index_vars(npargs[0]) + input_indexes = dict(npargs[0].indexes) for arg in npargs[1:][is_xarray[1:]]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(get_index_vars(arg)) + input_indexes.update(arg.indexes) if template is None: # infer template by providing zero-shaped arrays @@ -425,7 +421,7 @@ def map_blocks( preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template[k] for k in new_indexes}) + indexes.update({k: template.indexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } @@ -433,7 +429,7 @@ def map_blocks( else: # template xarray object has been provided with proper sizes and chunk shapes indexes = input_indexes - indexes.update(get_index_vars(template)) + indexes.update(template.indexes) if isinstance(template, 
DataArray):
             output_chunks = dict(zip(template.dims, template.chunks))  # type: ignore
         else:

From 04ffa6c4ba7fec3ec1c11e0a0b6a51e0a9b15545 Mon Sep 17 00:00:00 2001
From: dcherian
Date: Wed, 6 May 2020 13:36:15 -0600
Subject: [PATCH 12/23] Propagate attributes for index variables.

---
 xarray/core/parallel.py   | 3 +++
 xarray/tests/test_dask.py | 4 ++++
 2 files changed, 7 insertions(+)

diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index c2b09c119ad..7ae5c624ea2 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -541,6 +541,9 @@ def map_blocks(
         hlg.layers[gname_l] = layer
 
     result = Dataset(coords=indexes, attrs=template.attrs)
+    for index in result.indexes:
+        result[index].attrs = template[index].attrs
+
     for name, gname_l in var_key_map.items():
         dims = template[name].dims
         var_chunks = []
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 63234a5319e..ba51c034296 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -972,6 +972,7 @@ def make_da():
         coords={"x": np.arange(10), "y": np.arange(100, 120)},
         name="a",
     ).chunk({"x": 4, "y": 5})
+    da.x.attrs["long_name"] = "x"
     da.attrs["test"] = "test"
     da.coords["c2"] = 0.5
     da.coords["ndcoord"] = da.x * 2
@@ -995,6 +996,9 @@ def make_ds():
     map_ds.attrs["test"] = "test"
     map_ds.coords["xx"] = map_ds["a"] * map_ds.y
 
+    map_ds.x.attrs["long_name"] = "x"
+    map_ds.y.attrs["long_name"] = "y"
+
     return map_ds
 

From ed1bbabe7f17a16c4df9032fadb8c7bf4e8785a9 Mon Sep 17 00:00:00 2001
From: dcherian
Date: Sat, 9 May 2020 07:01:05 -0600
Subject: [PATCH 13/23] Propagate encoding for index variables.

---
 xarray/core/parallel.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 7ae5c624ea2..4c96d72536e 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -543,6 +543,7 @@ def map_blocks(
     result = Dataset(coords=indexes, attrs=template.attrs)
     for index in result.indexes:
         result[index].attrs = template[index].attrs
+        result[index].encoding = template[index].encoding
 
     for name, gname_l in var_key_map.items():
         dims = template[name].dims

From d28ea7577d784e9c1bc5d59bdecccbf022751445 Mon Sep 17 00:00:00 2001
From: dcherian
Date: Thu, 21 May 2020 10:30:00 -0600
Subject: [PATCH 14/23] Fix bug with reductions when template is provided.

indexes should just contain the indexes for the output variables.

When a template was provided, I was initializing indexes to contain all
input indexes. It should just have the indexes from the template;
otherwise, indexes for any indexed dimensions removed by func would
still be propagated.
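For example (adapted from the test added below, reusing ``da1``/``da2`` from
the existing dask-args tests):

    da1 = da1.chunk({"x": -1})
    da2 = da1 + 1
    mapped = xr.map_blocks(
        lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x")
    )
    # before this fix the result still carried the inputs' "x" index,
    # even though func removed that dimension
    xr.testing.assert_equal((da1 + da2).sum("x"), mapped)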
--- xarray/core/parallel.py | 6 +++--- xarray/tests/test_dask.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 4c96d72536e..87a7d956ae4 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -428,8 +428,7 @@ def map_blocks( else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = input_indexes - indexes.update(template.indexes) + indexes = dict(template.indexes) if isinstance(template, DataArray): output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore else: @@ -498,7 +497,8 @@ def map_blocks( expected["coords"] = set(template.coords.keys()) # type: ignore expected["indexes"] = { dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim in indexes + for dim in output_chunks + if dim in indexes } from_wrapper = (gname,) + chunk_tuple diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index ba51c034296..f2f38656490 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1132,6 +1132,22 @@ def sumda(da1, da2): mapped = xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) xr.testing.assert_equal(da1 + da1, mapped) + # reduction + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(lambda a, b: (a + b).sum("x"), da1, args=[da2]) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + + # reduction with template + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks( + lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x") + ) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): From 4937bfc843a3b310442b5392abb34f251f3353b3 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 21 May 2020 10:52:50 -0600 Subject: [PATCH 15/23] more minimal fix. --- xarray/core/parallel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 87a7d956ae4..a5ceb87f7da 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -497,8 +497,7 @@ def map_blocks( expected["coords"] = set(template.coords.keys()) # type: ignore expected["indexes"] = { dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim in output_chunks - if dim in indexes + for dim in indexes } from_wrapper = (gname,) + chunk_tuple From 5b8cad6524478431c356236217dd430b0c10e23f Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 25 May 2020 17:18:22 -0600 Subject: [PATCH 16/23] minimize diff --- xarray/core/parallel.py | 206 +++++++++++++++++++++------------------- 1 file changed, 106 insertions(+), 100 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a5ceb87f7da..fbe479f1490 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -74,57 +74,6 @@ def check_result_variables( ) -def subset_dataset_to_block( - graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index -): - """ - Creates a task that creates a subsets xarray dataset to a block determined by chunk_index; - whose extents are determined by input_chunk_bounds. - There are subtasks that create subsets of constituent variables. - - TODO: This is modifying graph in-place! - """ - - # this will become [[name1, variable1], - # [name2, variable2], - # ...] 
- # which is passed to dict and then to Dataset - data_vars = [] - coords = [] - - chunk_tuple = tuple(chunk_index.values()) - for name, variable in dataset.variables.items(): - # make a task that creates tuple of (dims, chunk) - if dask.is_dask_collection(variable.data): - # recursively index into dask_keys nested list to get chunk - chunk = variable.__dask_keys__() - for dim in variable.dims: - chunk = chunk[chunk_index[dim]] - - chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple - graph[chunk_variable_task] = (tuple, [variable.dims, chunk, variable.attrs]) - else: - # non-dask array with possibly chunked dimensions - # index into variable appropriately - subsetter = { - dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) - for dim in variable.dims - } - subset = variable.isel(subsetter) - chunk_variable_task = ( - "{}-{}".format(gname, dask.base.tokenize(subset)), - ) + chunk_tuple - graph[chunk_variable_task] = (tuple, [subset.dims, subset, subset.attrs]) - - # this task creates dict mapping variable name to above tuple - if name in dataset._coord_names: - coords.append([name, chunk_variable_task]) - else: - data_vars.append([name, chunk_variable_task]) - - return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) - - def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): raise TypeError("Expected Dataset, got %s" % type(obj)) @@ -215,55 +164,6 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping return slice(None) -def _wrapper( - func: Callable, - args: List, - kwargs: dict, - arg_is_array: Iterable[bool], - expected: dict, -): - """ - Wrapper function that receives datasets in args; converts to dataarrays when necessary; - passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. - """ - - check_shapes = dict(args[0].dims) - check_shapes.update(expected["shapes"]) - - converted_args = [ - dataset_to_dataarray(arg) if is_array else arg - for is_array, arg in zip(arg_is_array, args) - ] - - result = func(*converted_args, **kwargs) - - # check all dims are present - missing_dimensions = set(expected["shapes"]) - set(result.sizes) - if missing_dimensions: - raise ValueError(f"Dimensions {missing_dimensions} missing on returned object.") - - # check that index lengths and values are as expected - for name, index in result.indexes.items(): - if name in check_shapes: - if len(index) != check_shapes[name]: - raise ValueError( - f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." - ) - if name in expected["indexes"]: - expected_index = expected["indexes"][name] - if not index.equals(expected_index): - raise ValueError( - f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." - ) - - # check that all expected variables were returned - check_result_variables(result, expected, "coords") - if isinstance(result, Dataset): - check_result_variables(result, expected, "data_vars") - - return make_dict(result) - - def map_blocks( func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], @@ -371,6 +271,56 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 
1991-12-31 00:00:00 """ + def _wrapper( + func: Callable, + args: List, + kwargs: dict, + arg_is_array: Iterable[bool], + expected: dict, + ): + """ + Wrapper function that receives datasets in args; converts to dataarrays when necessary; + passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. + """ + + check_shapes = dict(args[0].dims) + check_shapes.update(expected["shapes"]) + + converted_args = [ + dataset_to_dataarray(arg) if is_array else arg + for is_array, arg in zip(arg_is_array, args) + ] + + result = func(*converted_args, **kwargs) + + # check all dims are present + missing_dimensions = set(expected["shapes"]) - set(result.sizes) + if missing_dimensions: + raise ValueError( + f"Dimensions {missing_dimensions} missing on returned object." + ) + + # check that index lengths and values are as expected + for name, index in result.indexes.items(): + if name in check_shapes: + if len(index) != check_shapes[name]: + raise ValueError( + f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." + ) + if name in expected["indexes"]: + expected_index = expected["indexes"][name] + if not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." + ) + + # check that all expected variables were returned + check_result_variables(result, expected, "coords") + if isinstance(result, Dataset): + check_result_variables(result, expected, "data_vars") + + return make_dict(result) + if template is not None and not isinstance(template, (DataArray, Dataset)): raise TypeError( f"template must be a DataArray or Dataset. Received {type(template).__name__} instead." @@ -474,6 +424,62 @@ def map_blocks( dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() } + def subset_dataset_to_block( + graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index + ): + """ + Creates a task that creates a subsets xarray dataset to a block determined by chunk_index; + whose extents are determined by input_chunk_bounds. + There are subtasks that create subsets of constituent variables. + + TODO: This is modifying graph in-place! + """ + + # this will become [[name1, variable1], + # [name2, variable2], + # ...] 
+ # which is passed to dict and then to Dataset + data_vars = [] + coords = [] + + chunk_tuple = tuple(chunk_index.values()) + for name, variable in dataset.variables.items(): + # make a task that creates tuple of (dims, chunk) + if dask.is_dask_collection(variable.data): + # recursively index into dask_keys nested list to get chunk + chunk = variable.__dask_keys__() + for dim in variable.dims: + chunk = chunk[chunk_index[dim]] + + chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple + graph[chunk_variable_task] = ( + tuple, + [variable.dims, chunk, variable.attrs], + ) + else: + # non-dask array with possibly chunked dimensions + # index into variable appropriately + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } + subset = variable.isel(subsetter) + chunk_variable_task = ( + "{}-{}".format(gname, dask.base.tokenize(subset)), + ) + chunk_tuple + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset, subset.attrs], + ) + + # this task creates dict mapping variable name to above tuple + if name in dataset._coord_names: + coords.append([name, chunk_variable_task]) + else: + data_vars.append([name, chunk_variable_task]) + + return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index From 790868075509fd65e81326d814f4e8b5f77dead6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 25 May 2020 17:23:57 -0600 Subject: [PATCH 17/23] Update docs. --- doc/whats-new.rst | 2 ++ xarray/core/dataarray.py | 5 ++--- xarray/core/dataset.py | 5 ++--- xarray/core/parallel.py | 10 +++++----- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b22a7217568..1cf9780492e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -70,6 +70,8 @@ New Features - :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases where the result of a computation could not be inferred automatically. By `Deepak Cherian `_ +- :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. (:pull:`3818`) + By `Deepak Cherian `_ Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 236938bac74..24600344ebb 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3270,9 +3270,8 @@ def map_blocks( This function cannot add a new chunked dimension. args: Sequence - Passed verbatim to func after unpacking, after the sliced DataArray. xarray - objects, if any, will not be split by chunks. Passing dask collections is - not allowed. + Passed verbatim to func after unpacking, after the sliced obj. + Any xarray objects will also be split by blocks and then passed on. kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3a55f3eca27..6ee5a70fd67 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5729,9 +5729,8 @@ def map_blocks( This function cannot add a new chunked dimension. args: Sequence - Passed verbatim to func after unpacking, after the sliced DataArray. xarray - objects, if any, will not be split by chunks. Passing dask collections is - not allowed. + Passed verbatim to func after unpacking, after the sliced obj. 
+ Any xarray objects will also be split by blocks and then passed on. kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fbe479f1490..fc20827c8a1 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -192,7 +192,7 @@ def map_blocks( Passed to the function as its first argument, one dask chunk at a time. args: Sequence Passed verbatim to func after unpacking, after the sliced obj. - Any xarray objects will also be split by chunks and then passed on. + Any xarray objects will also be split by blocks and then passed on. kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. @@ -431,8 +431,6 @@ def subset_dataset_to_block( Creates a task that creates a subsets xarray dataset to a block determined by chunk_index; whose extents are determined by input_chunk_bounds. There are subtasks that create subsets of constituent variables. - - TODO: This is modifying graph in-place! """ # this will become [[name1, variable1], @@ -457,7 +455,7 @@ def subset_dataset_to_block( [variable.dims, chunk, variable.attrs], ) else: - # non-dask array with possibly chunked dimensions + # non-dask array possibly with dimensions chunked on other variables # index into variable appropriately subsetter = { dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) @@ -492,7 +490,8 @@ def subset_dataset_to_block( for isxr, arg in zip(is_xarray, npargs) ] - # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper + # expected["shapes", "coords", "data_vars", "indexes"] are used to + # raise nice error messages in _wrapper expected = {} # input chunk 0 along a dimension maps to output chunk 0 along the same dimension # even if length of dimension is changed by the applied function @@ -566,6 +565,7 @@ def subset_dataset_to_block( hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) result[name] = (dims, data, template[name].attrs) + result[name].encoding = template[name].encoding result = result.set_coords(template._coord_names) From 2bdcc642afecd09579bf23d9f7c40c0ebf6f73d5 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 27 May 2020 10:29:39 -0600 Subject: [PATCH 18/23] Address joe comments. --- xarray/core/parallel.py | 28 ++++++++++++++-------------- xarray/tests/test_dask.py | 5 ++--- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fc20827c8a1..83f30a2895f 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -35,9 +35,9 @@ def to_object_array(iterable): + # using empty_like calls compute npargs = np.empty((len(iterable),), dtype=np.object) - for idx, item in enumerate(iterable): - npargs[idx] = item + npargs[:] = iterable return npargs @@ -180,9 +180,9 @@ def map_blocks( ---------- func: callable User-provided function that accepts a DataArray or Dataset as its first - parameter. The function will receive a subset of 'obj' (see below), + parameter. The function will receive a subset or 'block' of 'obj' (see below), corresponding to one chunk along each chunked dimension. ``func`` will be - executed as ``func(obj_subset, *args, **kwargs)``. + executed as ``func(subset_obj, *subset_args, **kwargs)``. This function must return either a single DataArray or a single Dataset. 
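(Illustration of the calling convention documented above: for ``obj`` chunked
as ``{"x": (5, 5)}``, ``func`` is conceptually invoked once per block,

    # block 0: func(obj.isel(x=slice(0, 5)),  *subset_args, **kwargs)
    # block 1: func(obj.isel(x=slice(5, 10)), *subset_args, **kwargs)

where ``subset_args`` holds any xarray objects from ``args`` subset to the
same block.)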
@@ -191,11 +191,12 @@ def map_blocks( obj: DataArray, Dataset Passed to the function as its first argument, one dask chunk at a time. args: Sequence - Passed verbatim to func after unpacking, after the sliced obj. + Passed verbatim to func after unpacking. Any xarray objects will also be split by blocks and then passed on. + xarray objects in args must be aligned with obj, otherwise an error is raised. kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be - split by chunks. Passing dask collections is not allowed. + split by chunks. Passing dask collections in kwargs is not allowed. template: (optional) DataArray, Dataset xarray object representing the final result after compute is called. If not provided, the function will be first run on mocked-up data, that looks like 'obj' but @@ -346,9 +347,8 @@ def _wrapper( is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs] is_array = [isinstance(arg, DataArray) for arg in npargs] - # align all xarray objects - # TODO: should we allow join as a kwarg or force everything to be aligned to the first object? - aligned = align(*npargs[is_xarray], join="left") + # all xarray objects must be aligned. This is consistent with apply_ufunc. + aligned = align(*npargs[is_xarray], join="exact") # assigning to object arrays works better when RHS is object array # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array npargs[is_xarray] = to_object_array(aligned) @@ -428,14 +428,14 @@ def subset_dataset_to_block( graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index ): """ - Creates a task that creates a subsets xarray dataset to a block determined by chunk_index; - whose extents are determined by input_chunk_bounds. - There are subtasks that create subsets of constituent variables. + Creates a task that subsets an xarray dataset to a block determined by chunk_index. + Block extents are determined by input_chunk_bounds. + Also subtasks that subset the constituent variables of a dataset. """ # this will become [[name1, variable1], - # [name2, variable2], - # ...] + # [name2, variable2], + # ...] # which is passed to dict and then to Dataset data_vars = [] coords = [] diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f2f38656490..eb06336d296 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1128,9 +1128,8 @@ def sumda(da1, da2): with raises_regex(ValueError, "Chunk sizes along dimension 'x'"): xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) - with raise_if_dask_computes(): - mapped = xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) - xr.testing.assert_equal(da1 + da1, mapped) + with raises_regex(ValueError, "indexes along dimension 'x' are not equal"): + xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) # reduction da1 = da1.chunk({"x": -1}) From 552571cc2acd506bfff2ecc85549510d0c6b9587 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 27 May 2020 10:35:00 -0600 Subject: [PATCH 19/23] docstring updates. 
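The updated docstrings also spell out the alignment requirement introduced by
the previous commit's ``join="exact"``: xarray objects in ``args`` must
already be aligned with ``obj``, so a misaligned argument raises instead of
being silently reindexed, e.g.

    # raises ValueError: indexes along dimension 'x' are not equal
    xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))])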
---
 xarray/core/dataarray.py | 85 ++++++++++++++++++++++++++++++---------
 xarray/core/dataset.py   | 86 +++++++++++++++++++++++++++++++---------
 xarray/core/parallel.py  | 31 +++++++--------
 3 files changed, 149 insertions(+), 53 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 24600344ebb..00b73d2e33c 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -3262,44 +3262,93 @@ def map_blocks(
         ----------
         func: callable
             User-provided function that accepts a DataArray as its first
-            parameter. The function will receive a subset, i.e. one block, of this DataArray
-            (see below), corresponding to one chunk along each chunked dimension. ``func`` will be
-            executed as ``func(block_subset, *args, **kwargs)``.
+            parameter. The function will receive a subset or 'block' of this DataArray (see below),
+            corresponding to one chunk along each chunked dimension. ``func`` will be
+            executed as ``func(subset_dataarray, *subset_args, **kwargs)``.

             This function must return either a single DataArray or a single Dataset.

             This function cannot add a new chunked dimension.
+
+        obj: DataArray, Dataset
+            Passed to the function as its first argument, one block at a time.
         args: Sequence
-            Passed verbatim to func after unpacking, after the sliced obj.
-            Any xarray objects will also be split by blocks and then passed on.
+            Passed to func after unpacking and subsetting any xarray objects by blocks.
+            xarray objects in args must be aligned with obj, otherwise an error is raised.
         kwargs: Mapping
             Passed verbatim to func after unpacking. xarray objects, if any, will not be
-            split by chunks. Passing dask collections is not allowed.
+            subset to blocks. Passing dask collections in kwargs is not allowed.
         template: (optional) DataArray, Dataset
             xarray object representing the final result after compute is called. If not provided,
-            the function will be first run on mocked-up data, that looks like 'obj' but
+            the function will first be run on mocked-up data that looks like ``obj`` but
             has sizes 0, to determine properties of the returned object such as dtype,
-            variable names, new dimensions and new indexes (if any).
-            'template' must be provided if the function changes the size of existing dimensions.
+            variable names, attributes, new dimensions and new indexes (if any).
+            ``template`` must be provided if the function changes the size of existing dimensions.
+            When provided, ``attrs`` on variables in ``template`` are copied over to the result. Any
+            ``attrs`` set by ``func`` will be ignored.
+

         Returns
         -------
-        A single DataArray or Dataset with dask backend, reassembled from the outputs of
-        the function.
+        A single DataArray or Dataset with dask backend, reassembled from the outputs of the
+        function.

         Notes
         -----
-        This method is designed for when one needs to manipulate a whole xarray object
-        within each chunk. In the more common case where one can work on numpy arrays,
-        it is recommended to use apply_ufunc.
+        This function is designed for when ``func`` needs to manipulate a whole xarray object
+        subset to each block. In the more common case where ``func`` can work on numpy arrays, it is
+        recommended to use ``apply_ufunc``.

-        If none of the variables in this DataArray is backed by dask, calling this
-        method is equivalent to calling ``func(self, *args, **kwargs)``.
+        If none of the variables in ``obj`` is backed by dask arrays, calling this function is
+        equivalent to calling ``func(obj, *args, **kwargs)``.
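A small sketch of the no-dask fallback described in the Notes above (values are illustrative):

    >>> import numpy as np
    >>> import xarray as xr
    >>> da = xr.DataArray(np.arange(4), dims="x")  # no dask-backed variables
    >>> da.map_blocks(lambda x: x + 1)  # same as calling the lambda directly on da
    <xarray.DataArray (x: 4)>
    array([1, 2, 3, 4])
    Dimensions without coordinates: x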
         See Also
         --------
-        dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks,
-        xarray.Dataset.map_blocks
+        dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks,
+        xarray.DataArray.map_blocks
+
+        Examples
+        --------
+
+        Calculate an anomaly from climatology using ``.groupby()``. Using
+        ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
+        its indices, and its methods like ``.groupby()``.
+
+        >>> def calculate_anomaly(da, groupby_type="time.month"):
+        ...     # Necessary workaround to xarray's check with zero dimensions
+        ...     # https://github.com/pydata/xarray/issues/3575
+        ...     gb = da.groupby(groupby_type)
+        ...     clim = gb.mean(dim="time")
+        ...     return gb - clim
+        >>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
+        >>> np.random.seed(123)
+        >>> array = xr.DataArray(
+        ...     np.random.rand(len(time)), dims="time", coords=[time]
+        ... ).chunk()
+        >>> array.map_blocks(calculate_anomaly, template=array).compute()
+        <xarray.DataArray (time: 24)>
+        array([ 0.12894847,  0.11323072, -0.0855964 , -0.09334032,  0.26848862,
+                0.12382735,  0.22460641,  0.07650108, -0.07673453, -0.22865714,
+               -0.19063865,  0.0590131 , -0.12894847, -0.11323072,  0.0855964 ,
+                0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
+                0.07673453,  0.22865714,  0.19063865, -0.0590131 ])
+        Coordinates:
+          * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
+
+        Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
+        to the function being applied in ``xr.map_blocks()``:
+
+        >>> array.map_blocks(
+        ...     calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
+        ... )
+        <xarray.DataArray (time: 24)>
+        array([ 0.15361741, -0.25671244, -0.31600032,  0.008463  ,  0.1766172 ,
+               -0.11974531,  0.43791243,  0.14197797, -0.06191987, -0.15073425,
+               -0.19967375,  0.18619794, -0.05100474, -0.42989909, -0.09153273,
+                0.24841842, -0.30708526, -0.31412523,  0.04197439,  0.0422506 ,
+                0.14482397,  0.35985481,  0.23487834,  0.12144652])
+        Coordinates:
+          * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
         """
         from .parallel import map_blocks

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 6ee5a70fd67..d083a551a60 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -5717,48 +5717,98 @@ def map_blocks(

         .. warning::
             This method is experimental and its signature may change.

-        Parameters
+        Parameters
         ----------
         func: callable
             User-provided function that accepts a Dataset as its first
-            parameter. The function will receive a subset, i.e. one block, of this Dataset
-            (see below), corresponding to one chunk along each chunked dimension. ``func`` will be
-            executed as ``func(block_subset, *args, **kwargs)``.
+            parameter. The function will receive a subset or 'block' of this Dataset (see below),
+            corresponding to one chunk along each chunked dimension. ``func`` will be
+            executed as ``func(subset_dataset, *subset_args, **kwargs)``.

             This function must return either a single DataArray or a single Dataset.

             This function cannot add a new chunked dimension.
+
+        obj: DataArray, Dataset
+            Passed to the function as its first argument, one block at a time.
         args: Sequence
-            Passed verbatim to func after unpacking, after the sliced obj.
-            Any xarray objects will also be split by blocks and then passed on.
+            Passed to func after unpacking and subsetting any xarray objects by blocks.
+            xarray objects in args must be aligned with obj, otherwise an error is raised.
         kwargs: Mapping
             Passed verbatim to func after unpacking. xarray objects, if any, will not be
-            split by chunks. Passing dask collections is not allowed.
+            subset to blocks. Passing dask collections in kwargs is not allowed.
         template: (optional) DataArray, Dataset
             xarray object representing the final result after compute is called. If not provided,
-            the function will be first run on mocked-up data, that looks like 'obj' but
+            the function will first be run on mocked-up data that looks like ``obj`` but
             has sizes 0, to determine properties of the returned object such as dtype,
-            variable names, new dimensions and new indexes (if any).
-            'template' must be provided if the function changes the size of existing dimensions.
+            variable names, attributes, new dimensions and new indexes (if any).
+            ``template`` must be provided if the function changes the size of existing dimensions.
+            When provided, ``attrs`` on variables in ``template`` are copied over to the result. Any
+            ``attrs`` set by ``func`` will be ignored.
+

         Returns
         -------
-        A single DataArray or Dataset with dask backend, reassembled from the outputs of
-        the function.
+        A single DataArray or Dataset with dask backend, reassembled from the outputs of the
+        function.

         Notes
         -----
-        This method is designed for when one needs to manipulate a whole xarray object
-        within each chunk. In the more common case where one can work on numpy arrays,
-        it is recommended to use apply_ufunc.
+        This function is designed for when ``func`` needs to manipulate a whole xarray object
+        subset to each block. In the more common case where ``func`` can work on numpy arrays, it is
+        recommended to use ``apply_ufunc``.

-        If none of the variables in this Dataset is backed by dask, calling this method
-        is equivalent to calling ``func(self, *args, **kwargs)``.
+        If none of the variables in ``obj`` is backed by dask arrays, calling this function is
+        equivalent to calling ``func(obj, *args, **kwargs)``.

         See Also
         --------
-        dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks,
+        dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks,
         xarray.DataArray.map_blocks
+
+        Examples
+        --------
+
+        Calculate an anomaly from climatology using ``.groupby()``. Using
+        ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
+        its indices, and its methods like ``.groupby()``.
+
+        >>> def calculate_anomaly(da, groupby_type="time.month"):
+        ...     # Necessary workaround to xarray's check with zero dimensions
+        ...     # https://github.com/pydata/xarray/issues/3575
+        ...     gb = da.groupby(groupby_type)
+        ...     clim = gb.mean(dim="time")
+        ...     return gb - clim
+        >>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
+        >>> np.random.seed(123)
+        >>> array = xr.DataArray(
+        ...     np.random.rand(len(time)), dims="time", coords=[time]
+        ... ).chunk()
+        >>> ds = xr.Dataset({"a": array})
+        >>> ds.map_blocks(calculate_anomaly, template=ds).compute()
+
+        array([ 0.12894847,  0.11323072, -0.0855964 , -0.09334032,  0.26848862,
+                0.12382735,  0.22460641,  0.07650108, -0.07673453, -0.22865714,
+               -0.19063865,  0.0590131 , -0.12894847, -0.11323072,  0.0855964 ,
+                0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
+                0.07673453,  0.22865714,  0.19063865, -0.0590131 ])
+        Coordinates:
+          * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
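Blockwise ``args`` work from the method form as well. A minimal sketch reusing ``ds`` and ``array`` from the example above (the helper name ``add_arrays`` is illustrative):

    >>> def add_arrays(left, right):
    ...     return left + right
    >>> result = ds.map_blocks(add_arrays, args=[array], template=ds)
    >>> xr.testing.assert_equal(ds + array, result)  # blocks of array line up with ds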
+
+        Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
+        to the function being applied in ``xr.map_blocks()``:
+
+        >>> ds.map_blocks(
+        ...     calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=ds,
+        ... )
+
+        array([ 0.15361741, -0.25671244, -0.31600032,  0.008463  ,  0.1766172 ,
+               -0.11974531,  0.43791243,  0.14197797, -0.06191987, -0.15073425,
+               -0.19967375,  0.18619794, -0.05100474, -0.42989909, -0.09153273,
+                0.24841842, -0.30708526, -0.31412523,  0.04197439,  0.0422506 ,
+                0.14482397,  0.35985481,  0.23487834,  0.12144652])
+        Coordinates:
+          * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
         """
         from .parallel import map_blocks

diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 83f30a2895f..577ab36a292 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -180,7 +180,7 @@ def map_blocks(
     ----------
     func: callable
         User-provided function that accepts a DataArray or Dataset as its first
-        parameter. The function will receive a subset or 'block' of 'obj' (see below),
+        parameter ``obj``. The function will receive a subset or 'block' of ``obj`` (see below),
         corresponding to one chunk along each chunked dimension. ``func`` will be
         executed as ``func(subset_obj, *subset_args, **kwargs)``.

@@ -189,22 +189,21 @@ def map_blocks(
     This function cannot add a new chunked dimension.

     obj: DataArray, Dataset
-        Passed to the function as its first argument, one dask chunk at a time.
+        Passed to the function as its first argument, one block at a time.
     args: Sequence
-        Passed verbatim to func after unpacking.
-        Any xarray objects will also be split by blocks and then passed on.
+        Passed to func after unpacking and subsetting any xarray objects by blocks.
         xarray objects in args must be aligned with obj, otherwise an error is raised.
     kwargs: Mapping
         Passed verbatim to func after unpacking. xarray objects, if any, will not be
-        split by chunks. Passing dask collections in kwargs is not allowed.
+        subset to blocks. Passing dask collections in kwargs is not allowed.
     template: (optional) DataArray, Dataset
         xarray object representing the final result after compute is called. If not provided,
-        the function will be first run on mocked-up data, that looks like 'obj' but
+        the function will first be run on mocked-up data that looks like ``obj`` but
         has sizes 0, to determine properties of the returned object such as dtype,
         variable names, attributes, new dimensions and new indexes (if any).
-        'template' must be provided if the function changes the size of existing dimensions.
-        When provided, `attrs` on variables in `template` are copied over to the result. Any
-        `attrs` set by `func` will be ignored.
+        ``template`` must be provided if the function changes the size of existing dimensions.
+        When provided, ``attrs`` on variables in ``template`` are copied over to the result. Any
+        ``attrs`` set by ``func`` will be ignored.


     Returns
@@ -214,11 +213,11 @@ def map_blocks(

     Notes
     -----
-    This function is designed for when one needs to manipulate a whole xarray object
-    within each chunk. In the more common case where one can work on numpy arrays, it is
-    recommended to use apply_ufunc.
+    This function is designed for when ``func`` needs to manipulate a whole xarray object
+    subset to each block. In the more common case where ``func`` can work on numpy arrays, it is
+    recommended to use ``apply_ufunc``.

-    If none of the variables in obj is backed by dask, calling this function is
+    If none of the variables in ``obj`` is backed by dask arrays, calling this function is
     equivalent to calling ``func(obj, *args, **kwargs)``.

     See Also
@@ -236,8 +235,6 @@ def map_blocks(
     its indices, and its methods like ``.groupby()``.

     >>> def calculate_anomaly(da, groupby_type="time.month"):
     ...     # Necessary workaround to xarray's check with zero dimensions
     ...     # https://github.com/pydata/xarray/issues/3575
-    ...     if sum(da.shape) == 0:
-    ...         return da
     ...     gb = da.groupby(groupby_type)
     ...     clim = gb.mean(dim="time")
     ...     return gb - clim
@@ -246,7 +243,7 @@ def map_blocks(
     >>> array = xr.DataArray(
     ...     np.random.rand(len(time)), dims="time", coords=[time]
     ... ).chunk()
-    >>> xr.map_blocks(calculate_anomaly, array).compute()
+    >>> xr.map_blocks(calculate_anomaly, array, template=array).compute()
     <xarray.DataArray (time: 24)>
     array([ 0.12894847,  0.11323072, -0.0855964 , -0.09334032,  0.26848862,
             0.12382735,  0.22460641,  0.07650108, -0.07673453, -0.22865714,
            -0.19063865,  0.0590131 , -0.12894847, -0.11323072,  0.0855964 ,
             0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
             0.07673453,  0.22865714,  0.19063865, -0.0590131 ])
     Coordinates:
       * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
@@ -260,7 +257,7 @@ def map_blocks(
     to the function being applied in ``xr.map_blocks()``:

     >>> xr.map_blocks(
-    ...     calculate_anomaly, array, kwargs={"groupby_type": "time.year"},
+    ...     calculate_anomaly, array, kwargs={"groupby_type": "time.year"}, template=array,
     ... )
     <xarray.DataArray (time: 24)>
     array([ 0.15361741, -0.25671244, -0.31600032,  0.008463  ,  0.1766172 ,

From 10427bba9f68f5849703c71c61d687ce819288b9 Mon Sep 17 00:00:00 2001
From: dcherian
Date: Wed, 27 May 2020 10:41:07 -0600
Subject: [PATCH 20/23] minor docstring change

---
 xarray/core/dataarray.py | 2 --
 xarray/core/dataset.py   | 2 --
 xarray/core/parallel.py  | 2 --
 3 files changed, 6 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 00b73d2e33c..3451ff14c8f 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -3315,8 +3315,6 @@ def map_blocks(
         its indices, and its methods like ``.groupby()``.

         >>> def calculate_anomaly(da, groupby_type="time.month"):
-        ...     # Necessary workaround to xarray's check with zero dimensions
-        ...     # https://github.com/pydata/xarray/issues/3575
         ...     gb = da.groupby(groupby_type)
         ...     clim = gb.mean(dim="time")
         ...     return gb - clim
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index d083a551a60..63e27a8e007 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -5774,8 +5774,6 @@ def map_blocks(
         its indices, and its methods like ``.groupby()``.

         >>> def calculate_anomaly(da, groupby_type="time.month"):
-        ...     # Necessary workaround to xarray's check with zero dimensions
-        ...     # https://github.com/pydata/xarray/issues/3575
         ...     gb = da.groupby(groupby_type)
         ...     clim = gb.mean(dim="time")
         ...     return gb - clim
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 577ab36a292..86fff838162 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -233,8 +233,6 @@ def map_blocks(
     its indices, and its methods like ``.groupby()``.

     >>> def calculate_anomaly(da, groupby_type="time.month"):
-    ...     # Necessary workaround to xarray's check with zero dimensions
-    ...     # https://github.com/pydata/xarray/issues/3575
     ...     gb = da.groupby(groupby_type)
     ...     clim = gb.mean(dim="time")
     ...     return gb - clim
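For context on the deletions above: when ``template`` is not supplied, ``map_blocks`` first calls ``func`` on zero-sized mock inputs to infer the result, which is what the removed guard worked around. The old pattern, kept here as a sketch for reference:

    >>> def calculate_anomaly(da, groupby_type="time.month"):
    ...     if sum(da.shape) == 0:
    ...         # short-circuit template inference on zero-sized inputs
    ...         return da
    ...     gb = da.groupby(groupby_type)
    ...     return gb - gb.mean(dim="time")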
From ba522e07b6e18587cb6ae069715daeca6439efbc Mon Sep 17 00:00:00 2001
From: dcherian
Date: Wed, 27 May 2020 10:43:56 -0600
Subject: [PATCH 21/23] minor.

---
 xarray/core/parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 86fff838162..a8dedc45534 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -430,7 +430,7 @@ def subset_dataset_to_block(

     # this will become [[name1, variable1],
     #                   [name2, variable2],
-    #                  ...]
+    #                   ...]
     # which is passed to dict and then to Dataset
     data_vars = []
     coords = []

From db9fa9f29e603ae3dc94366de691c1bfd1f2d34f Mon Sep 17 00:00:00 2001
From: dcherian
Date: Wed, 27 May 2020 10:47:02 -0600
Subject: [PATCH 22/23] remove useless check_shapes variable.

---
 xarray/core/parallel.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index a8dedc45534..522c5b36ff5 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -279,9 +279,6 @@ def _wrapper(
         passes these to the user function `func` and checks returned objects for expected
         shapes/sizes/etc.
         """
-        check_shapes = dict(args[0].dims)
-        check_shapes.update(expected["shapes"])
-
         converted_args = [
             dataset_to_dataarray(arg) if is_array else arg
             for is_array, arg in zip(arg_is_array, args)
         ]
@@ -298,10 +295,10 @@ def _wrapper(
         # check that index lengths and values are as expected
         for name, index in result.indexes.items():
-            if name in check_shapes:
-                if len(index) != check_shapes[name]:
+            if name in expected["shapes"]:
+                if len(index) != expected["shapes"][name]:
                     raise ValueError(
-                        f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}."
+                        f"Received dimension {name!r} of length {len(index)}. Expected length {expected['shapes'][name]}."
                     )
             if name in expected["indexes"]:
                 expected_index = expected["indexes"][name]

From 5644c652a70a59d2dc27d441dbf0b17dfcb0860a Mon Sep 17 00:00:00 2001
From: dcherian
Date: Fri, 29 May 2020 11:43:48 -0600
Subject: [PATCH 23/23] fix docstring

---
 xarray/core/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 63e27a8e007..29cecae55b0 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -5717,7 +5717,7 @@ def map_blocks(

         .. warning::
             This method is experimental and its signature may change.

-        Parameters
+        Parameters
         ----------
         func: callable
             User-provided function that accepts a Dataset as its first